1use crate::core::error::{RedicatError, Result};
4use itertools::Itertools;
5use nalgebra_sparse::ops::serial::spadd_csr_prealloc;
6use nalgebra_sparse::ops::Op;
7use nalgebra_sparse::{CooMatrix, CsrMatrix};
8use rayon::prelude::*;
9use rustc_hash::FxHashMap;
10use smallvec::SmallVec;
11
12pub struct SparseOps;
13
14impl SparseOps {
15 pub fn from_triplets_u32(
17 nrows: usize,
18 ncols: usize,
19 triplets: Vec<(usize, usize, u32)>,
20 ) -> Result<CsrMatrix<u32>> {
21 if nrows == 0 || ncols == 0 {
22 return Ok(CsrMatrix::zeros(nrows, ncols));
23 }
24
25 if triplets.is_empty() {
26 return Ok(CsrMatrix::zeros(nrows, ncols));
27 }
28
29 for &(row, col, _) in &triplets {
31 if row >= nrows || col >= ncols {
32 return Err(RedicatError::InvalidInput(format!(
33 "Index ({}, {}) exceeds matrix dimensions ({}, {})",
34 row, col, nrows, ncols
35 )));
36 }
37 }
38
39 let (row_indices, col_indices, values): (Vec<_>, Vec<_>, Vec<_>) =
41 triplets.into_iter().multiunzip();
42
43 let coo = CooMatrix::try_from_triplets(nrows, ncols, row_indices, col_indices, values)
44 .map_err(|e| RedicatError::SparseMatrix(format!("COO creation failed: {:?}", e)))?;
45
46 let csr = CsrMatrix::from(&coo);
48 Ok(csr)
49 }
50
51 pub fn from_triplets(
53 nrows: usize,
54 ncols: usize,
55 triplets: Vec<(usize, usize, u8)>,
56 ) -> Result<CsrMatrix<u8>> {
57 if nrows == 0 || ncols == 0 {
58 return Ok(CsrMatrix::zeros(nrows, ncols));
59 }
60
61 if triplets.is_empty() {
62 return Ok(CsrMatrix::zeros(nrows, ncols));
63 }
64
65 let (row_indices, col_indices, values): (Vec<_>, Vec<_>, Vec<_>) =
66 triplets.into_iter().multiunzip();
67
68 let coo = CooMatrix::try_from_triplets(nrows, ncols, row_indices, col_indices, values)
69 .map_err(|e| RedicatError::SparseMatrix(format!("COO creation failed: {:?}", e)))?;
70
71 Ok(CsrMatrix::from(&coo))
72 }
73
74 pub fn add_matrices(a: &CsrMatrix<u32>, b: &CsrMatrix<u32>) -> Result<CsrMatrix<u32>> {
76 if a.nrows() != b.nrows() || a.ncols() != b.ncols() {
77 return Err(RedicatError::DimensionMismatch {
78 expected: format!("{}×{}", a.nrows(), a.ncols()),
79 actual: format!("{}×{}", b.nrows(), b.ncols()),
80 });
81 }
82
83 let pattern = nalgebra_sparse::ops::serial::spadd_pattern(a.pattern(), b.pattern());
85
86 let mut result =
88 CsrMatrix::try_from_pattern_and_values(pattern.clone(), vec![0u32; pattern.nnz()])
89 .map_err(|e| {
90 RedicatError::SparseMatrix(format!("Failed to create result matrix: {:?}", e))
91 })?;
92
93 spadd_csr_prealloc(1u32, &mut result, 1u32, Op::NoOp(a))
96 .map_err(|e| RedicatError::SparseMatrix(format!("Sparse addition failed: {:?}", e)))?;
97
98 spadd_csr_prealloc(1u32, &mut result, 1u32, Op::NoOp(b))
99 .map_err(|e| RedicatError::SparseMatrix(format!("Sparse addition failed: {:?}", e)))?;
100
101 Ok(result)
102 }
103
104 pub fn filter_columns_u32(
106 matrix: &CsrMatrix<u32>,
107 keep_indices: &[usize],
108 ) -> Result<CsrMatrix<u32>> {
109 let nrows = matrix.nrows();
110 let new_ncols = keep_indices.len();
111
112 if new_ncols == 0 {
113 return Ok(CsrMatrix::zeros(nrows, 0));
114 }
115
116 let col_map: FxHashMap<usize, usize> = keep_indices
118 .iter()
119 .enumerate()
120 .map(|(new_idx, &old_idx)| (old_idx, new_idx))
121 .collect();
122
123 let mut new_row_offsets = Vec::with_capacity(nrows + 1);
125 let mut new_col_indices = Vec::new();
126 let mut new_values = Vec::new();
127
128 new_row_offsets.push(0);
129
130 for row_idx in 0..nrows {
131 let row = matrix.row(row_idx);
132
133 for (&old_col, &val) in row.col_indices().iter().zip(row.values()) {
134 if let Some(&new_col) = col_map.get(&old_col) {
135 new_col_indices.push(new_col);
136 new_values.push(val);
137 }
138 }
139
140 new_row_offsets.push(new_col_indices.len());
141 }
142
143 CsrMatrix::try_from_csr_data(
145 nrows,
146 new_ncols,
147 new_row_offsets,
148 new_col_indices,
149 new_values,
150 )
151 .map_err(|e| {
152 RedicatError::SparseMatrix(format!("Failed to create filtered matrix: {:?}", e))
153 })
154 }
155
156 pub fn compute_row_sums(matrix: &CsrMatrix<u32>) -> Vec<u32> {
158 (0..matrix.nrows())
159 .into_par_iter()
160 .map(|row_idx| {
161 let row = matrix.row(row_idx);
162 row.values()
163 .iter()
164 .fold(0u64, |acc, &val| acc.saturating_add(val as u64))
165 .min(u32::MAX as u64) as u32
166 })
167 .collect()
168 }
169
170 pub fn compute_col_sums(matrix: &CsrMatrix<u32>) -> Vec<u32> {
172 let ncols = matrix.ncols();
173
174 let chunk_size = std::cmp::max(1, matrix.nrows() / rayon::current_num_threads());
176
177 (0..matrix.nrows())
178 .into_par_iter()
179 .chunks(chunk_size)
180 .map(|chunk| {
181 let mut local_sums = vec![0u64; ncols];
182 for row_idx in chunk {
183 let row = matrix.row(row_idx);
184 for (&col_idx, &val) in row.col_indices().iter().zip(row.values()) {
185 local_sums[col_idx] = local_sums[col_idx].saturating_add(val as u64);
186 }
187 }
188 local_sums
189 })
190 .reduce(
191 || vec![0u64; ncols],
192 |mut acc, local| {
193 for (i, val) in local.into_iter().enumerate() {
194 acc[i] = acc[i].saturating_add(val);
195 }
196 acc
197 },
198 )
199 .into_iter()
200 .map(|sum| (sum.min(u32::MAX as u64)) as u32)
201 .collect()
202 }
203
204 pub fn element_wise_multiply(a: &CsrMatrix<u32>, b: &CsrMatrix<u8>) -> Result<CsrMatrix<u32>> {
211 if a.nrows() != b.nrows() || a.ncols() != b.ncols() {
212 return Err(RedicatError::DimensionMismatch {
213 expected: format!("{}×{}", a.nrows(), a.ncols()),
214 actual: format!("{}×{}", b.nrows(), b.ncols()),
215 });
216 }
217
218 let triplets: Vec<(usize, usize, u32)> = (0..a.nrows())
221 .into_par_iter()
222 .flat_map(|row_idx| {
223 let a_row = a.row(row_idx);
224 let b_row = b.row(row_idx);
225
226 let a_cols = a_row.col_indices();
227 let a_vals = a_row.values();
228 let b_cols = b_row.col_indices();
229 let b_vals = b_row.values();
230
231 let mut result: SmallVec<[(usize, usize, u32); 32]> = SmallVec::new();
234 let mut a_idx = 0;
235 let mut b_idx = 0;
236
237 while a_idx < a_cols.len() && b_idx < b_cols.len() {
238 let a_col = a_cols[a_idx];
239 let b_col = b_cols[b_idx];
240
241 match a_col.cmp(&b_col) {
242 std::cmp::Ordering::Equal => {
243 if b_vals[b_idx] > 0 {
245 result.push((row_idx, a_col, a_vals[a_idx]));
246 }
247 a_idx += 1;
248 b_idx += 1;
249 }
250 std::cmp::Ordering::Less => {
251 a_idx += 1;
253 }
254 std::cmp::Ordering::Greater => {
255 b_idx += 1;
257 }
258 }
259 }
260
261 result.into_vec()
263 })
264 .collect();
265
266 Self::from_triplets_u32(a.nrows(), a.ncols(), triplets)
267 }
268
269 pub fn transpose_u32(matrix: &CsrMatrix<u32>) -> CsrMatrix<u32> {
271 matrix.transpose()
272 }
273
274 pub fn matrix_vector_multiply(matrix: &CsrMatrix<u32>, vector: &[u32]) -> Result<Vec<u32>> {
276 if matrix.ncols() != vector.len() {
277 return Err(RedicatError::DimensionMismatch {
278 expected: format!("vector length = {}", matrix.ncols()),
279 actual: format!("vector length = {}", vector.len()),
280 });
281 }
282
283 let mut result = vec![0u64; matrix.nrows()];
284
285 result
287 .par_iter_mut()
288 .enumerate()
289 .for_each(|(row_idx, result_val)| {
290 let row = matrix.row(row_idx);
291 *result_val = row.col_indices().iter().zip(row.values()).fold(
292 0u64,
293 |acc, (&col_idx, &mat_val)| {
294 acc.saturating_add((mat_val as u64) * (vector[col_idx] as u64))
295 },
296 );
297 });
298
299 Ok(result
300 .into_iter()
301 .map(|val| (val.min(u32::MAX as u64)) as u32)
302 .collect())
303 }
304
305 pub fn get_density_stats(matrix: &CsrMatrix<u32>) -> (f64, usize, usize) {
307 let total_elements = matrix.nrows() * matrix.ncols();
308 let nnz = matrix.nnz();
309 let density = if total_elements > 0 {
310 nnz as f64 / total_elements as f64
311 } else {
312 0.0
313 };
314 (density, nnz, total_elements)
315 }
316}
317
318pub trait SparseMatrixExt<T> {
320 fn apply_threshold(&self, threshold: T) -> CsrMatrix<T>
321 where
322 T: Copy + PartialOrd + Default + nalgebra::Scalar;
323}
324
325impl SparseMatrixExt<u32> for CsrMatrix<u32> {
326 fn apply_threshold(&self, threshold: u32) -> CsrMatrix<u32> {
328 let triplets: Vec<(usize, usize, u32)> = self
329 .triplet_iter()
330 .filter_map(|(row, col, &val)| {
331 if val >= threshold {
332 Some((row, col, val))
333 } else {
334 None
335 }
336 })
337 .collect();
338
339 SparseOps::from_triplets_u32(self.nrows(), self.ncols(), triplets)
340 .unwrap_or_else(|_| CsrMatrix::zeros(self.nrows(), self.ncols()))
341 }
342}