//! cjc_runtime/sparse.rs — sparse tensor representations (CSR + COO).
1use std::collections::BTreeMap;
2
3use cjc_repro::kahan_sum_f64;
4
5use crate::accumulator::binned_sum_f64;
6use crate::error::RuntimeError;
7use crate::tensor::Tensor;
8
9// ---------------------------------------------------------------------------
10// 4. Sparse Tensor Representations (CSR + COO)
11// ---------------------------------------------------------------------------
12
/// Compressed Sparse Row (CSR) matrix representation.
///
/// Row `r`'s entries occupy indices `row_offsets[r]..row_offsets[r + 1]`
/// of the parallel vectors `values` and `col_indices`.
#[derive(Debug, Clone)]
pub struct SparseCsr {
    // Stored (non-zero) values, ordered row by row.
    pub values: Vec<f64>,
    // Column index of each entry in `values` (parallel vector).
    pub col_indices: Vec<usize>,
    // Per-row start offsets into `values`/`col_indices`.
    pub row_offsets: Vec<usize>, // length = nrows + 1
    // Number of rows in the logical (dense) matrix.
    pub nrows: usize,
    // Number of columns in the logical (dense) matrix.
    pub ncols: usize,
}
22
23impl SparseCsr {
24    /// Number of non-zero elements.
25    pub fn nnz(&self) -> usize {
26        self.values.len()
27    }
28
29    /// Access element at (row, col). Returns 0.0 for zero entries.
30    pub fn get(&self, row: usize, col: usize) -> f64 {
31        if row >= self.nrows || col >= self.ncols {
32            return 0.0;
33        }
34        let start = self.row_offsets[row];
35        let end = self.row_offsets[row + 1];
36        for idx in start..end {
37            if self.col_indices[idx] == col {
38                return self.values[idx];
39            }
40        }
41        0.0
42    }
43
44    /// Sparse matrix-vector multiplication: y = A * x.
45    pub fn matvec(&self, x: &[f64]) -> Result<Vec<f64>, RuntimeError> {
46        if x.len() != self.ncols {
47            return Err(RuntimeError::DimensionMismatch {
48                expected: self.ncols,
49                got: x.len(),
50            });
51        }
52        let mut y = vec![0.0f64; self.nrows];
53        for row in 0..self.nrows {
54            let start = self.row_offsets[row];
55            let end = self.row_offsets[row + 1];
56            let products: Vec<f64> = (start..end)
57                .map(|idx| self.values[idx] * x[self.col_indices[idx]])
58                .collect();
59            y[row] = kahan_sum_f64(&products);
60        }
61        Ok(y)
62    }
63
64    /// Convert to dense Tensor.
65    pub fn to_dense(&self) -> Tensor {
66        let mut data = vec![0.0f64; self.nrows * self.ncols];
67        for row in 0..self.nrows {
68            let start = self.row_offsets[row];
69            let end = self.row_offsets[row + 1];
70            for idx in start..end {
71                data[row * self.ncols + self.col_indices[idx]] = self.values[idx];
72            }
73        }
74        Tensor::from_vec(data, &[self.nrows, self.ncols]).unwrap()
75    }
76
77    /// Construct CSR from COO data.
78    pub fn from_coo(coo: &SparseCoo) -> Self {
79        // Sort by row, then by column
80        let nnz = coo.values.len();
81        let mut order: Vec<usize> = (0..nnz).collect();
82        order.sort_by_key(|&i| (coo.row_indices[i], coo.col_indices[i]));
83
84        let mut values = Vec::with_capacity(nnz);
85        let mut col_indices = Vec::with_capacity(nnz);
86        let mut row_offsets = vec![0usize; coo.nrows + 1];
87
88        for &i in &order {
89            values.push(coo.values[i]);
90            col_indices.push(coo.col_indices[i]);
91            row_offsets[coo.row_indices[i] + 1] += 1;
92        }
93
94        // Cumulative sum for row_offsets
95        for i in 1..=coo.nrows {
96            row_offsets[i] += row_offsets[i - 1];
97        }
98
99        SparseCsr {
100            values,
101            col_indices,
102            row_offsets,
103            nrows: coo.nrows,
104            ncols: coo.ncols,
105        }
106    }
107}
108
/// Coordinate (COO) sparse matrix representation.
///
/// Entry `i` is the triple (`row_indices[i]`, `col_indices[i]`, `values[i]`);
/// no ordering of entries is required.
#[derive(Debug, Clone)]
pub struct SparseCoo {
    // Stored values, one per coordinate pair.
    pub values: Vec<f64>,
    // Row index of each stored value (parallel vector).
    pub row_indices: Vec<usize>,
    // Column index of each stored value (parallel vector).
    pub col_indices: Vec<usize>,
    // Number of rows in the logical (dense) matrix.
    pub nrows: usize,
    // Number of columns in the logical (dense) matrix.
    pub ncols: usize,
}
118
119impl SparseCoo {
120    pub fn new(
121        values: Vec<f64>,
122        row_indices: Vec<usize>,
123        col_indices: Vec<usize>,
124        nrows: usize,
125        ncols: usize,
126    ) -> Self {
127        SparseCoo {
128            values,
129            row_indices,
130            col_indices,
131            nrows,
132            ncols,
133        }
134    }
135
136    pub fn nnz(&self) -> usize {
137        self.values.len()
138    }
139
140    pub fn to_csr(&self) -> SparseCsr {
141        SparseCsr::from_coo(self)
142    }
143
144    pub fn sum(&self) -> f64 {
145        kahan_sum_f64(&self.values)
146    }
147}
148
149// ---------------------------------------------------------------------------
150// Sparse Arithmetic Operations
151// ---------------------------------------------------------------------------
152
/// Helper: merge two sorted CSR rows element-wise using a combiner function.
/// Returns (values, col_indices) for the merged row, dropping exact zeros.
///
/// Columns present in only one row are combined with the other side's
/// default value (`default_a` / `default_b`).
fn merge_rows(
    a_vals: &[f64],
    a_cols: &[usize],
    b_vals: &[f64],
    b_cols: &[usize],
    combine: fn(f64, f64) -> f64,
    default_a: f64,
    default_b: f64,
) -> (Vec<f64>, Vec<usize>) {
    let mut merged_vals = Vec::new();
    let mut merged_cols = Vec::new();
    let mut ia = 0usize;
    let mut ib = 0usize;

    // Record a combined entry unless it is exactly zero.
    let mut emit = |v: f64, c: usize| {
        if v != 0.0 {
            merged_vals.push(v);
            merged_cols.push(c);
        }
    };

    // Single merge pass over both rows; `None` marks an exhausted side.
    loop {
        match (a_cols.get(ia).copied(), b_cols.get(ib).copied()) {
            (Some(ca), Some(cb)) if ca == cb => {
                emit(combine(a_vals[ia], b_vals[ib]), ca);
                ia += 1;
                ib += 1;
            }
            (Some(ca), Some(cb)) if ca < cb => {
                emit(combine(a_vals[ia], default_b), ca);
                ia += 1;
            }
            (Some(_), Some(cb)) => {
                emit(combine(default_a, b_vals[ib]), cb);
                ib += 1;
            }
            (Some(ca), None) => {
                emit(combine(a_vals[ia], default_b), ca);
                ia += 1;
            }
            (None, Some(cb)) => {
                emit(combine(default_a, b_vals[ib]), cb);
                ib += 1;
            }
            (None, None) => break,
        }
    }

    (merged_vals, merged_cols)
}
216
217/// Apply an element-wise binary operation on two CSR matrices of the same dimensions.
218fn sparse_binop(
219    a: &SparseCsr,
220    b: &SparseCsr,
221    combine: fn(f64, f64) -> f64,
222    default_a: f64,
223    default_b: f64,
224    op_name: &str,
225) -> Result<SparseCsr, String> {
226    if a.nrows != b.nrows || a.ncols != b.ncols {
227        return Err(format!(
228            "sparse_{}: dimension mismatch: ({}, {}) vs ({}, {})",
229            op_name, a.nrows, a.ncols, b.nrows, b.ncols
230        ));
231    }
232
233    let mut values = Vec::new();
234    let mut col_indices = Vec::new();
235    let mut row_offsets = Vec::with_capacity(a.nrows + 1);
236    row_offsets.push(0);
237
238    for row in 0..a.nrows {
239        let a_start = a.row_offsets[row];
240        let a_end = a.row_offsets[row + 1];
241        let b_start = b.row_offsets[row];
242        let b_end = b.row_offsets[row + 1];
243
244        let (rv, rc) = merge_rows(
245            &a.values[a_start..a_end],
246            &a.col_indices[a_start..a_end],
247            &b.values[b_start..b_end],
248            &b.col_indices[b_start..b_end],
249            combine,
250            default_a,
251            default_b,
252        );
253        values.extend_from_slice(&rv);
254        col_indices.extend_from_slice(&rc);
255        row_offsets.push(values.len());
256    }
257
258    Ok(SparseCsr {
259        values,
260        col_indices,
261        row_offsets,
262        nrows: a.nrows,
263        ncols: a.ncols,
264    })
265}
266
267/// Element-wise addition of two CSR matrices (same dimensions).
268pub fn sparse_add(a: &SparseCsr, b: &SparseCsr) -> Result<SparseCsr, String> {
269    sparse_binop(a, b, |x, y| x + y, 0.0, 0.0, "add")
270}
271
272/// Element-wise subtraction of two CSR matrices (same dimensions).
273pub fn sparse_sub(a: &SparseCsr, b: &SparseCsr) -> Result<SparseCsr, String> {
274    sparse_binop(a, b, |x, y| x - y, 0.0, 0.0, "sub")
275}
276
277/// Element-wise multiplication (Hadamard product) of two CSR matrices.
278/// Only positions where BOTH matrices have non-zeros produce non-zeros.
279pub fn sparse_mul(a: &SparseCsr, b: &SparseCsr) -> Result<SparseCsr, String> {
280    if a.nrows != b.nrows || a.ncols != b.ncols {
281        return Err(format!(
282            "sparse_mul: dimension mismatch: ({}, {}) vs ({}, {})",
283            a.nrows, a.ncols, b.nrows, b.ncols
284        ));
285    }
286
287    let mut values = Vec::new();
288    let mut col_indices = Vec::new();
289    let mut row_offsets = Vec::with_capacity(a.nrows + 1);
290    row_offsets.push(0);
291
292    for row in 0..a.nrows {
293        let a_start = a.row_offsets[row];
294        let a_end = a.row_offsets[row + 1];
295        let b_start = b.row_offsets[row];
296        let b_end = b.row_offsets[row + 1];
297
298        let mut ia = a_start;
299        let mut ib = b_start;
300
301        // Only emit where both have entries (Hadamard)
302        while ia < a_end && ib < b_end {
303            match a.col_indices[ia].cmp(&b.col_indices[ib]) {
304                std::cmp::Ordering::Less => ia += 1,
305                std::cmp::Ordering::Greater => ib += 1,
306                std::cmp::Ordering::Equal => {
307                    let v = a.values[ia] * b.values[ib];
308                    if v != 0.0 {
309                        values.push(v);
310                        col_indices.push(a.col_indices[ia]);
311                    }
312                    ia += 1;
313                    ib += 1;
314                }
315            }
316        }
317        row_offsets.push(values.len());
318    }
319
320    Ok(SparseCsr {
321        values,
322        col_indices,
323        row_offsets,
324        nrows: a.nrows,
325        ncols: a.ncols,
326    })
327}
328
329/// Sparse matrix-matrix multiplication (SpGEMM): C = A * B.
330/// Uses row-wise accumulation with BTreeMap for deterministic column ordering.
331/// All floating-point reductions use binned summation.
332pub fn sparse_matmul(a: &SparseCsr, b: &SparseCsr) -> Result<SparseCsr, String> {
333    if a.ncols != b.nrows {
334        return Err(format!(
335            "sparse_matmul: inner dimension mismatch: A is ({}, {}), B is ({}, {})",
336            a.nrows, a.ncols, b.nrows, b.ncols
337        ));
338    }
339
340    let mut values = Vec::new();
341    let mut col_indices = Vec::new();
342    let mut row_offsets = Vec::with_capacity(a.nrows + 1);
343    row_offsets.push(0);
344
345    for row in 0..a.nrows {
346        // Accumulate contributions into a BTreeMap for deterministic column order.
347        let mut accum: BTreeMap<usize, Vec<f64>> = BTreeMap::new();
348
349        let a_start = a.row_offsets[row];
350        let a_end = a.row_offsets[row + 1];
351
352        for a_idx in a_start..a_end {
353            let k = a.col_indices[a_idx];
354            let a_val = a.values[a_idx];
355
356            let b_start = b.row_offsets[k];
357            let b_end = b.row_offsets[k + 1];
358
359            for b_idx in b_start..b_end {
360                let j = b.col_indices[b_idx];
361                accum.entry(j).or_default().push(a_val * b.values[b_idx]);
362            }
363        }
364
365        // BTreeMap iterates in sorted column order (deterministic)
366        for (col, terms) in &accum {
367            let v = binned_sum_f64(&terms);
368            if v != 0.0 {
369                col_indices.push(*col);
370                values.push(v);
371            }
372        }
373        row_offsets.push(values.len());
374    }
375
376    Ok(SparseCsr {
377        values,
378        col_indices,
379        row_offsets,
380        nrows: a.nrows,
381        ncols: b.ncols,
382    })
383}
384
385/// Scalar multiplication: every non-zero element is multiplied by `s`.
386pub fn sparse_scalar_mul(a: &SparseCsr, s: f64) -> SparseCsr {
387    let values: Vec<f64> = a.values.iter().map(|&v| v * s).collect();
388    SparseCsr {
389        values,
390        col_indices: a.col_indices.clone(),
391        row_offsets: a.row_offsets.clone(),
392        nrows: a.nrows,
393        ncols: a.ncols,
394    }
395}
396
397/// Transpose a CSR matrix. Returns a new CSR where rows and columns are swapped.
398pub fn sparse_transpose(a: &SparseCsr) -> SparseCsr {
399    // Build COO in (col, row) order, then convert to CSR of transposed shape.
400    let mut row_counts = vec![0usize; a.ncols + 1];
401
402    // Count entries per column of A (= per row of A^T)
403    for &c in &a.col_indices {
404        row_counts[c + 1] += 1;
405    }
406    // Prefix sum
407    for i in 1..=a.ncols {
408        row_counts[i] += row_counts[i - 1];
409    }
410
411    let nnz = a.values.len();
412    let mut new_values = vec![0.0f64; nnz];
413    let mut new_col_indices = vec![0usize; nnz];
414    let mut cursor = row_counts.clone();
415
416    for row in 0..a.nrows {
417        let start = a.row_offsets[row];
418        let end = a.row_offsets[row + 1];
419        for idx in start..end {
420            let col = a.col_indices[idx];
421            let dest = cursor[col];
422            new_values[dest] = a.values[idx];
423            new_col_indices[dest] = row;
424            cursor[col] += 1;
425        }
426    }
427
428    SparseCsr {
429        values: new_values,
430        col_indices: new_col_indices,
431        row_offsets: row_counts,
432        nrows: a.ncols,
433        ncols: a.nrows,
434    }
435}
436
437// ---------------------------------------------------------------------------
438// Tests
439// ---------------------------------------------------------------------------
440
#[cfg(test)]
mod tests {
    use super::*;

    /// Helper: build a small CSR matrix from dense data.
    /// `data` is row-major with length `nrows * ncols`; zeros are dropped.
    fn csr_from_dense(data: &[f64], nrows: usize, ncols: usize) -> SparseCsr {
        let mut values = Vec::new();
        let mut col_indices = Vec::new();
        let mut row_offsets = vec![0usize];

        for r in 0..nrows {
            for c in 0..ncols {
                let v = data[r * ncols + c];
                if v != 0.0 {
                    values.push(v);
                    col_indices.push(c);
                }
            }
            row_offsets.push(values.len());
        }

        SparseCsr { values, col_indices, row_offsets, nrows, ncols }
    }

    // -- sparse_add --

    #[test]
    fn test_sparse_add_basic() {
        let a = csr_from_dense(&[1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 0.0, 4.0, 5.0], 3, 3);
        let b = csr_from_dense(&[0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 4.0, 0.0, 5.0], 3, 3);
        let c = sparse_add(&a, &b).unwrap();
        // Dense result: [1,1,2, 2,3,3, 4,4,10]
        for r in 0..3 {
            for col in 0..3 {
                let expected = a.get(r, col) + b.get(r, col);
                assert_eq!(c.get(r, col), expected, "mismatch at ({}, {})", r, col);
            }
        }
    }

    #[test]
    fn test_sparse_add_a_plus_a_eq_2a() {
        // A + A must equal 2 * A element-wise.
        let a = csr_from_dense(&[1.0, 2.0, 0.0, 3.0], 2, 2);
        let sum = sparse_add(&a, &a).unwrap();
        let doubled = sparse_scalar_mul(&a, 2.0);
        for r in 0..2 {
            for c in 0..2 {
                assert_eq!(sum.get(r, c), doubled.get(r, c));
            }
        }
    }

    #[test]
    fn test_sparse_add_dimension_mismatch() {
        let a = csr_from_dense(&[1.0, 2.0], 1, 2);
        let b = csr_from_dense(&[1.0, 2.0, 3.0], 1, 3);
        assert!(sparse_add(&a, &b).is_err());
    }

    // -- sparse_sub --

    #[test]
    fn test_sparse_sub_basic() {
        let a = csr_from_dense(&[5.0, 3.0, 0.0, 1.0], 2, 2);
        let b = csr_from_dense(&[2.0, 3.0, 1.0, 0.0], 2, 2);
        let c = sparse_sub(&a, &b).unwrap();
        assert_eq!(c.get(0, 0), 3.0);
        assert_eq!(c.get(0, 1), 0.0); // 3 - 3 = 0, should be dropped
        assert_eq!(c.get(1, 0), -1.0);
        assert_eq!(c.get(1, 1), 1.0);
    }

    #[test]
    fn test_sparse_sub_self_is_zero() {
        // A - A cancels every entry exactly, leaving no stored values.
        let a = csr_from_dense(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let c = sparse_sub(&a, &a).unwrap();
        assert_eq!(c.nnz(), 0);
    }

    // -- sparse_mul (Hadamard) --

    #[test]
    fn test_sparse_mul_hadamard() {
        let a = csr_from_dense(&[1.0, 0.0, 3.0, 4.0], 2, 2);
        let b = csr_from_dense(&[2.0, 5.0, 0.0, 3.0], 2, 2);
        let c = sparse_mul(&a, &b).unwrap();
        assert_eq!(c.get(0, 0), 2.0);  // 1*2
        assert_eq!(c.get(0, 1), 0.0);  // one is zero
        assert_eq!(c.get(1, 0), 0.0);  // one is zero
        assert_eq!(c.get(1, 1), 12.0); // 4*3
    }

    // -- sparse_matmul --

    #[test]
    fn test_sparse_matmul_identity() {
        // A * I = A
        let a = csr_from_dense(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let eye = csr_from_dense(&[1.0, 0.0, 0.0, 1.0], 2, 2);
        let c = sparse_matmul(&a, &eye).unwrap();
        for r in 0..2 {
            for col in 0..2 {
                assert_eq!(c.get(r, col), a.get(r, col));
            }
        }
    }

    #[test]
    fn test_sparse_matmul_vs_dense() {
        // Compare sparse matmul against dense result for a small case
        let a_data = [1.0, 2.0, 0.0, 0.0, 3.0, 4.0];
        let b_data = [5.0, 0.0, 6.0, 7.0, 0.0, 8.0];
        let a = csr_from_dense(&a_data, 2, 3);
        let b = csr_from_dense(&b_data, 3, 2);

        let c = sparse_matmul(&a, &b).unwrap();

        // Dense: A(2x3) * B(3x2)
        // C[0,0] = 1*5 + 2*6 + 0*0 = 17
        // C[0,1] = 1*0 + 2*7 + 0*8 = 14
        // C[1,0] = 0*5 + 3*6 + 4*0 = 18
        // C[1,1] = 0*0 + 3*7 + 4*8 = 53
        assert_eq!(c.get(0, 0), 17.0);
        assert_eq!(c.get(0, 1), 14.0);
        assert_eq!(c.get(1, 0), 18.0);
        assert_eq!(c.get(1, 1), 53.0);
    }

    #[test]
    fn test_sparse_matmul_dimension_mismatch() {
        // Inner dimensions disagree: (1x2) * (1x3) is invalid.
        let a = csr_from_dense(&[1.0, 2.0], 1, 2);
        let b = csr_from_dense(&[1.0, 2.0, 3.0], 1, 3);
        assert!(sparse_matmul(&a, &b).is_err());
    }

    // -- sparse_scalar_mul --

    #[test]
    fn test_sparse_scalar_mul_basic() {
        let a = csr_from_dense(&[2.0, 0.0, 0.0, 4.0], 2, 2);
        let c = sparse_scalar_mul(&a, 3.0);
        assert_eq!(c.get(0, 0), 6.0);
        assert_eq!(c.get(1, 1), 12.0);
        assert_eq!(c.nnz(), 2);
    }

    // -- sparse_transpose --

    #[test]
    fn test_sparse_transpose_square() {
        let a = csr_from_dense(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let at = sparse_transpose(&a);
        assert_eq!(at.get(0, 0), 1.0);
        assert_eq!(at.get(0, 1), 3.0);
        assert_eq!(at.get(1, 0), 2.0);
        assert_eq!(at.get(1, 1), 4.0);
    }

    #[test]
    fn test_sparse_transpose_rect() {
        // Rectangular case: shape must swap and every entry must move.
        let a = csr_from_dense(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        let at = sparse_transpose(&a);
        assert_eq!(at.nrows, 3);
        assert_eq!(at.ncols, 2);
        for r in 0..2 {
            for c in 0..3 {
                assert_eq!(at.get(c, r), a.get(r, c), "mismatch at transpose({}, {})", c, r);
            }
        }
    }

    #[test]
    fn test_sparse_transpose_double_is_identity() {
        // (A^T)^T = A
        let a = csr_from_dense(&[1.0, 0.0, 2.0, 3.0, 0.0, 4.0], 2, 3);
        let att = sparse_transpose(&sparse_transpose(&a));
        assert_eq!(att.nrows, a.nrows);
        assert_eq!(att.ncols, a.ncols);
        for r in 0..a.nrows {
            for c in 0..a.ncols {
                assert_eq!(att.get(r, c), a.get(r, c));
            }
        }
    }

    // -- Determinism --

    #[test]
    fn test_sparse_matmul_determinism() {
        // Repeated runs must produce bit-identical structure and values.
        let a = csr_from_dense(&[1.0, 2.0, 0.0, 0.0, 3.0, 4.0], 2, 3);
        let b = csr_from_dense(&[5.0, 0.0, 6.0, 7.0, 0.0, 8.0], 3, 2);

        let c1 = sparse_matmul(&a, &b).unwrap();
        let c2 = sparse_matmul(&a, &b).unwrap();

        assert_eq!(c1.values, c2.values);
        assert_eq!(c1.col_indices, c2.col_indices);
        assert_eq!(c1.row_offsets, c2.row_offsets);
    }

    #[test]
    fn test_sparse_add_determinism() {
        // Repeated runs must produce bit-identical structure and values.
        let a = csr_from_dense(&[1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 0.0, 4.0, 5.0], 3, 3);
        let b = csr_from_dense(&[0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 4.0, 0.0, 5.0], 3, 3);

        let c1 = sparse_add(&a, &b).unwrap();
        let c2 = sparse_add(&a, &b).unwrap();

        assert_eq!(c1.values, c2.values);
        assert_eq!(c1.col_indices, c2.col_indices);
    }
}
652