1use crate::csr::CsrMatrix;
14use crate::error::SparseError;
15
16pub fn spgemm(a: &CsrMatrix<f32>, b: &CsrMatrix<f32>) -> Result<CsrMatrix<f32>, SparseError> {
25 if a.cols() != b.rows() {
26 return Err(SparseError::SpMVDimensionMismatch {
27 matrix_cols: a.cols(),
28 x_len: b.rows(),
29 });
30 }
31
32 let m = a.rows();
33 let n = b.cols();
34
35 let mut c_offsets = Vec::with_capacity(m + 1);
36 let mut c_col_indices = Vec::new();
37 let mut c_values = Vec::new();
38
39 let mut work = vec![0.0_f32; n];
41 let mut marker = vec![false; n];
42 let mut col_list = Vec::new();
43
44 c_offsets.push(0u32);
45
46 for i in 0..m {
47 accumulate_row(a, b, i, &mut work, &mut marker, &mut col_list);
48 emit_row(
49 &mut c_col_indices,
50 &mut c_values,
51 &mut c_offsets,
52 &mut work,
53 &mut marker,
54 &mut col_list,
55 );
56 }
57
58 CsrMatrix::new(m, n, c_offsets, c_col_indices, c_values)
59}
60
61fn accumulate_row(
63 a: &CsrMatrix<f32>,
64 b: &CsrMatrix<f32>,
65 i: usize,
66 work: &mut [f32],
67 marker: &mut [bool],
68 col_list: &mut Vec<usize>,
69) {
70 let a_off = a.offsets();
71 let a_cols = a.col_indices();
72 let a_vals = a.values();
73 let b_off = b.offsets();
74 let b_cols = b.col_indices();
75 let b_vals = b.values();
76
77 let a_start = a_off[i] as usize;
78 let a_end = a_off[i + 1] as usize;
79
80 for a_idx in a_start..a_end {
81 let k = a_cols[a_idx] as usize;
82 let a_val = a_vals[a_idx];
83
84 let b_start = b_off[k] as usize;
85 let b_end = b_off[k + 1] as usize;
86
87 for b_idx in b_start..b_end {
88 let j = b_cols[b_idx] as usize;
89 if !marker[j] {
90 marker[j] = true;
91 col_list.push(j);
92 }
93 work[j] += a_val * b_vals[b_idx];
94 }
95 }
96}
97
/// Flushes one accumulated row from the scratch buffers into the CSR output.
///
/// The columns recorded in `col_list` are sorted ascending. For each column
/// `j`, the accumulated value `work[j]` is appended to `col_indices` /
/// `values` unless its magnitude is at most `f32::EPSILON`. NaN values are
/// always kept, so an invalid product is not silently turned into a
/// structural zero.
///
/// Every visited slot of `work` and `marker` is reset (`0.0` / `false`),
/// `col_list` is left empty, and the running total `col_indices.len()` is
/// pushed onto `offsets` as the end offset of the row.
///
/// # Panics
///
/// Panics if an index in `col_list` is out of bounds for `work` or `marker`.
fn emit_row(
    col_indices: &mut Vec<u32>,
    values: &mut Vec<f32>,
    offsets: &mut Vec<u32>,
    work: &mut [f32],
    marker: &mut [bool],
    col_list: &mut Vec<usize>,
) {
    col_list.sort_unstable();

    // Single pass: emit the entry (if kept) and reset its scratch slots.
    // `drain(..)` also leaves `col_list` empty for the next row.
    for j in col_list.drain(..) {
        let val = work[j];
        // `NaN.abs() > EPSILON` is false, so NaN must be tested explicitly or
        // it would be dropped together with the near-zero entries.
        if val.is_nan() || val.abs() > f32::EPSILON {
            // NOTE(review): `as u32` wraps silently if `j` exceeds `u32::MAX`;
            // presumably callers only pass columns that came from u32 indices.
            col_indices.push(j as u32);
            values.push(val);
        }
        work[j] = 0.0;
        marker[j] = false;
    }

    // NOTE(review): this narrowing cast wraps if the total number of stored
    // entries exceeds `u32::MAX`; there is no error channel here to report it.
    offsets.push(col_indices.len() as u32);
}