Skip to main content

trueno_sparse/
spgemm.rs

1//! Sparse General Matrix-Matrix Multiply (SpGEMM).
2//!
3//! # Contract: sparse-spgemm-v1.yaml
4//!
5//! Computes C = A * B where A and B are sparse CSR matrices.
6//! Uses Gustavson's algorithm (row-by-row with a dense workspace accumulator).
7//!
8//! ## Proof obligations
9//! - Associativity: (AB)C = A(BC) within tolerance
10//! - Identity: AI = A, IA = A
11//! - Zero: A×0 = 0
12
13use crate::csr::CsrMatrix;
14use crate::error::SparseError;
15
16/// Sparse matrix-matrix multiply: C = A * B (both CSR).
17///
18/// Uses Gustavson's algorithm: for each row of A, scatter-gather
19/// into a dense workspace, then compress into CSR.
20///
21/// # Errors
22///
23/// Returns error if A.cols() != B.rows().
24pub fn spgemm(a: &CsrMatrix<f32>, b: &CsrMatrix<f32>) -> Result<CsrMatrix<f32>, SparseError> {
25    if a.cols() != b.rows() {
26        return Err(SparseError::SpMVDimensionMismatch {
27            matrix_cols: a.cols(),
28            x_len: b.rows(),
29        });
30    }
31
32    let m = a.rows();
33    let n = b.cols();
34
35    let mut c_offsets = Vec::with_capacity(m + 1);
36    let mut c_col_indices = Vec::new();
37    let mut c_values = Vec::new();
38
39    // Dense workspace for accumulating one row of C
40    let mut work = vec![0.0_f32; n];
41    let mut marker = vec![false; n];
42    let mut col_list = Vec::new();
43
44    c_offsets.push(0u32);
45
46    for i in 0..m {
47        accumulate_row(a, b, i, &mut work, &mut marker, &mut col_list);
48        emit_row(
49            &mut c_col_indices,
50            &mut c_values,
51            &mut c_offsets,
52            &mut work,
53            &mut marker,
54            &mut col_list,
55        );
56    }
57
58    CsrMatrix::new(m, n, c_offsets, c_col_indices, c_values)
59}
60
61/// Accumulate row i of C = A * B into workspace.
62fn accumulate_row(
63    a: &CsrMatrix<f32>,
64    b: &CsrMatrix<f32>,
65    i: usize,
66    work: &mut [f32],
67    marker: &mut [bool],
68    col_list: &mut Vec<usize>,
69) {
70    let a_off = a.offsets();
71    let a_cols = a.col_indices();
72    let a_vals = a.values();
73    let b_off = b.offsets();
74    let b_cols = b.col_indices();
75    let b_vals = b.values();
76
77    let a_start = a_off[i] as usize;
78    let a_end = a_off[i + 1] as usize;
79
80    for a_idx in a_start..a_end {
81        let k = a_cols[a_idx] as usize;
82        let a_val = a_vals[a_idx];
83
84        let b_start = b_off[k] as usize;
85        let b_end = b_off[k + 1] as usize;
86
87        for b_idx in b_start..b_end {
88            let j = b_cols[b_idx] as usize;
89            if !marker[j] {
90                marker[j] = true;
91                col_list.push(j);
92            }
93            work[j] += a_val * b_vals[b_idx];
94        }
95    }
96}
97
/// Emit the accumulated row into the CSR arrays and reset the workspace.
///
/// `col_list` names the workspace slots touched while accumulating the
/// current row. The columns are sorted ascending, every slot whose value is
/// not exactly zero is appended to `col_indices` / `values`, and the running
/// entry count is pushed onto `offsets` as the end offset of this row.
///
/// Only exact zeros (including `-0.0`) are dropped. Small-magnitude results
/// are legitimate entries and are kept; using machine epsilon as an absolute
/// cut-off would discard them and break `A * I = A` for matrices with tiny
/// values. NaN results are kept as well so that invalid arithmetic stays
/// visible in the output instead of vanishing.
///
/// On return every touched slot of `work` is `0.0`, every touched `marker`
/// flag is `false`, and `col_list` is empty, ready for the next row.
///
/// # Panics
///
/// Panics if an index in `col_list` is out of bounds for `work` or `marker`,
/// or if the total number of stored entries no longer fits in a `u32` offset.
fn emit_row(
    col_indices: &mut Vec<u32>,
    values: &mut Vec<f32>,
    offsets: &mut Vec<u32>,
    work: &mut [f32],
    marker: &mut [bool],
    col_list: &mut Vec<usize>,
) {
    // CSR rows are emitted with ascending column indices.
    col_list.sort_unstable();

    // Single pass: read the slot, reset it, and keep the value if non-zero.
    // Draining leaves `col_list` empty while retaining its capacity.
    for j in col_list.drain(..) {
        let val = std::mem::take(&mut work[j]);
        marker[j] = false;

        // `!=` is true for NaN, so NaN is retained; exact cancellation is not.
        if val != 0.0 {
            // NOTE(review): assumes `j` fits in u32 (presumably it was widened
            // from a u32 CSR column index by the caller) - confirm.
            col_indices.push(j as u32);
            values.push(val);
        }
    }

    // Refuse to write a wrapped-around offset; `as u32` would truncate silently.
    let nnz = col_indices.len();
    assert!(
        nnz <= u32::MAX as usize,
        "SpGEMM result has more stored entries than a u32 offset can address"
    );
    offsets.push(nnz as u32);
}