Skip to main content

anofox_ml_core/
sparse.rs

//! Compressed Sparse Row matrix type, intended for high-vocab text
//! vectorisation output and downstream sparse-friendly estimators.
//!
//! Layout matches scipy.sparse.csr_matrix:
//!
//! - `indptr`: length `n_rows + 1`. Row `i` occupies the slice
//!   `data[indptr[i]..indptr[i+1]]` / `indices[indptr[i]..indptr[i+1]]`.
//! - `indices`: column indices for each non-zero, sorted ascending within a
//!   row (callers must maintain this invariant for predictable behaviour).
//! - `data`: the stored values, held in parallel with `indices` (entry `k`
//!   of `data` sits in column `indices[k]`).
//!
//! Operations are kept minimal: `from_triplets`, `nnz`, `density`,
//! `row_iter`, `to_dense`, `matvec`.
13
14use ndarray::Array2;
15
16use crate::float::Float;
17
18#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
19#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
20pub struct CsrMatrix<F: Float> {
21    pub indptr: Vec<usize>,
22    pub indices: Vec<usize>,
23    pub data: Vec<F>,
24    pub n_rows: usize,
25    pub n_cols: usize,
26}
27
28impl<F: Float> CsrMatrix<F> {
29    /// Build from a list of `(row, col, value)` triplets. Triplets do not
30    /// need to be sorted; this constructor sorts within each row by column.
31    /// Duplicate `(row, col)` entries are summed.
32    pub fn from_triplets(n_rows: usize, n_cols: usize, triplets: Vec<(usize, usize, F)>) -> Self {
33        // Bucket per row.
34        let mut buckets: Vec<Vec<(usize, F)>> = vec![Vec::new(); n_rows];
35        for (r, c, v) in triplets {
36            buckets[r].push((c, v));
37        }
38        // Sort + dedup-by-column within each row.
39        let mut indptr = Vec::with_capacity(n_rows + 1);
40        let mut indices = Vec::new();
41        let mut data = Vec::new();
42        indptr.push(0);
43        for row in buckets.iter_mut() {
44            row.sort_by(|a, b| a.0.cmp(&b.0));
45            // Sum duplicates.
46            let mut last_col: Option<usize> = None;
47            for &(c, v) in row.iter() {
48                if Some(c) == last_col {
49                    let n = data.len();
50                    data[n - 1] = data[n - 1] + v;
51                } else {
52                    indices.push(c);
53                    data.push(v);
54                    last_col = Some(c);
55                }
56            }
57            indptr.push(indices.len());
58        }
59        Self {
60            indptr,
61            indices,
62            data,
63            n_rows,
64            n_cols,
65        }
66    }
67
68    pub fn nnz(&self) -> usize {
69        self.data.len()
70    }
71
72    pub fn density(&self) -> f64 {
73        if self.n_rows == 0 || self.n_cols == 0 {
74            return 0.0;
75        }
76        self.nnz() as f64 / (self.n_rows as f64 * self.n_cols as f64)
77    }
78
79    /// Iterate non-zeros of row `i` as `(col, value)` pairs.
80    pub fn row_iter(&self, i: usize) -> impl Iterator<Item = (usize, F)> + '_ {
81        let start = self.indptr[i];
82        let end = self.indptr[i + 1];
83        self.indices[start..end]
84            .iter()
85            .copied()
86            .zip(self.data[start..end].iter().copied())
87    }
88
89    pub fn to_dense(&self) -> Array2<F> {
90        let mut out = Array2::<F>::zeros((self.n_rows, self.n_cols));
91        for i in 0..self.n_rows {
92            for (c, v) in self.row_iter(i) {
93                out[[i, c]] = v;
94            }
95        }
96        out
97    }
98
99    /// Sparse-dense matrix-vector multiply: `y = A x`. Returns a dense
100    /// vector of length `n_rows`.
101    pub fn matvec(&self, x: &[F]) -> Vec<F> {
102        assert_eq!(x.len(), self.n_cols, "matvec: dimension mismatch");
103        let mut y = vec![F::zero(); self.n_rows];
104        for i in 0..self.n_rows {
105            let mut s = F::zero();
106            for (c, v) in self.row_iter(i) {
107                s = s + v * x[c];
108            }
109            y[i] = s;
110        }
111        y
112    }
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    /// Every supplied triplet shows up at its dense position, and a cell that
    /// received no triplet stays zero.
    #[test]
    fn test_csr_from_triplets_basic() {
        // Dense picture of the 3×4 fixture:
        //   row 0: 1 0 0 2
        //   row 1: 0 3 0 0
        //   row 2: 0 0 4 5
        let entries: Vec<(usize, usize, f64)> = vec![
            (0, 0, 1.0),
            (0, 3, 2.0),
            (1, 1, 3.0),
            (2, 2, 4.0),
            (2, 3, 5.0),
        ];
        let matrix = CsrMatrix::<f64>::from_triplets(3, 4, entries.clone());
        assert_eq!(matrix.nnz(), 5);

        let dense = matrix.to_dense();
        for (r, c, expected) in entries {
            assert_eq!(dense[[r, c]], expected);
        }
        assert_eq!(dense[[1, 0]], 0.0);
    }

    /// Three triplets aimed at the same cell collapse into one stored entry
    /// holding their sum.
    #[test]
    fn test_csr_duplicate_triplets_sum() {
        let repeated = vec![(0, 1, 1.0), (0, 1, 2.0), (0, 1, 3.0)];
        let matrix = CsrMatrix::<f64>::from_triplets(1, 3, repeated);
        assert_eq!(matrix.nnz(), 1);

        let dense = matrix.to_dense();
        assert_eq!(dense[[0, 1]], 6.0);
    }

    /// diag(1, 2) applied to [3, 4] gives [3, 8].
    #[test]
    fn test_csr_matvec() {
        let diagonal = CsrMatrix::<f64>::from_triplets(2, 2, vec![(0, 0, 1.0), (1, 1, 2.0)]);
        let x = [3.0, 4.0];
        let product = diagonal.matvec(&x);
        assert_eq!(product, vec![3.0, 8.0]);
    }

    /// One stored entry in a 2×2 matrix is a density of one quarter.
    #[test]
    fn test_csr_density() {
        let single = CsrMatrix::<f64>::from_triplets(2, 2, vec![(0, 0, 1.0)]);
        let gap = (single.density() - 0.25).abs();
        assert!(gap < 1e-12);
    }
}