1use crate::core::error::{RedicatError, Result};
4use itertools::Itertools;
5use nalgebra_sparse::ops::serial::spadd_csr_prealloc;
6use nalgebra_sparse::ops::Op;
7use nalgebra_sparse::{CooMatrix, CsrMatrix};
8use rayon::prelude::*;
9use rustc_hash::FxHashMap;
10use smallvec::SmallVec;
11use std::io::{Read, Write};
12use std::path::Path;
13
/// Namespace type for stateless sparse-matrix helper routines.
///
/// The struct carries no data; every operation is exposed as an associated
/// function (e.g. `SparseOps::add_matrices`).
#[derive(Debug, Clone, Copy, Default)]
pub struct SparseOps;
15
16impl SparseOps {
17 pub fn from_triplets_u32(
19 nrows: usize,
20 ncols: usize,
21 triplets: Vec<(usize, usize, u32)>,
22 ) -> Result<CsrMatrix<u32>> {
23 if nrows == 0 || ncols == 0 {
24 return Ok(CsrMatrix::zeros(nrows, ncols));
25 }
26
27 if triplets.is_empty() {
28 return Ok(CsrMatrix::zeros(nrows, ncols));
29 }
30
31 for &(row, col, _) in &triplets {
33 if row >= nrows || col >= ncols {
34 return Err(RedicatError::InvalidInput(format!(
35 "Index ({}, {}) exceeds matrix dimensions ({}, {})",
36 row, col, nrows, ncols
37 )));
38 }
39 }
40
41 let (row_indices, col_indices, values): (Vec<_>, Vec<_>, Vec<_>) =
43 triplets.into_iter().multiunzip();
44
45 let coo = CooMatrix::try_from_triplets(nrows, ncols, row_indices, col_indices, values)
46 .map_err(|e| RedicatError::SparseMatrix(format!("COO creation failed: {:?}", e)))?;
47
48 let csr = CsrMatrix::from(&coo);
50 Ok(csr)
51 }
52
53 pub fn from_triplets(
55 nrows: usize,
56 ncols: usize,
57 triplets: Vec<(usize, usize, u8)>,
58 ) -> Result<CsrMatrix<u8>> {
59 if nrows == 0 || ncols == 0 {
60 return Ok(CsrMatrix::zeros(nrows, ncols));
61 }
62
63 if triplets.is_empty() {
64 return Ok(CsrMatrix::zeros(nrows, ncols));
65 }
66
67 let (row_indices, col_indices, values): (Vec<_>, Vec<_>, Vec<_>) =
68 triplets.into_iter().multiunzip();
69
70 let coo = CooMatrix::try_from_triplets(nrows, ncols, row_indices, col_indices, values)
71 .map_err(|e| RedicatError::SparseMatrix(format!("COO creation failed: {:?}", e)))?;
72
73 Ok(CsrMatrix::from(&coo))
74 }
75
76 pub fn add_matrices(a: &CsrMatrix<u32>, b: &CsrMatrix<u32>) -> Result<CsrMatrix<u32>> {
78 if a.nrows() != b.nrows() || a.ncols() != b.ncols() {
79 return Err(RedicatError::DimensionMismatch {
80 expected: format!("{}×{}", a.nrows(), a.ncols()),
81 actual: format!("{}×{}", b.nrows(), b.ncols()),
82 });
83 }
84
85 let pattern = nalgebra_sparse::ops::serial::spadd_pattern(a.pattern(), b.pattern());
87
88 let mut result =
90 CsrMatrix::try_from_pattern_and_values(pattern.clone(), vec![0u32; pattern.nnz()])
91 .map_err(|e| {
92 RedicatError::SparseMatrix(format!("Failed to create result matrix: {:?}", e))
93 })?;
94
95 spadd_csr_prealloc(1u32, &mut result, 1u32, Op::NoOp(a))
98 .map_err(|e| RedicatError::SparseMatrix(format!("Sparse addition failed: {:?}", e)))?;
99
100 spadd_csr_prealloc(1u32, &mut result, 1u32, Op::NoOp(b))
101 .map_err(|e| RedicatError::SparseMatrix(format!("Sparse addition failed: {:?}", e)))?;
102
103 Ok(result)
104 }
105
106 pub fn parallel_sum_matrices(matrices: &[&CsrMatrix<u32>]) -> Result<CsrMatrix<u32>> {
120 if matrices.is_empty() {
121 return Err(RedicatError::EmptyData("No matrices to sum".to_string()));
122 }
123
124 if matrices.len() == 1 {
125 return Ok(matrices[0].clone());
126 }
127
128 let (nrows, ncols) = (matrices[0].nrows(), matrices[0].ncols());
130 for matrix in matrices.iter().skip(1) {
131 if matrix.nrows() != nrows || matrix.ncols() != ncols {
132 return Err(RedicatError::DimensionMismatch {
133 expected: format!("{}×{}", nrows, ncols),
134 actual: format!("{}×{}", matrix.nrows(), matrix.ncols()),
135 });
136 }
137 }
138
139 let mut union_pattern = matrices[0].pattern().clone();
141 for matrix in matrices.iter().skip(1) {
142 union_pattern =
143 nalgebra_sparse::ops::serial::spadd_pattern(&union_pattern, matrix.pattern());
144 }
145
146 let nnz = union_pattern.nnz();
148 let mut result =
149 CsrMatrix::try_from_pattern_and_values(union_pattern, vec![0u32; nnz]).map_err(
150 |e| RedicatError::SparseMatrix(format!("Failed to create result matrix: {:?}", e)),
151 )?;
152
153 for matrix in matrices {
155 spadd_csr_prealloc(1u32, &mut result, 1u32, Op::NoOp(*matrix)).map_err(|e| {
156 RedicatError::SparseMatrix(format!("Sparse addition failed: {:?}", e))
157 })?;
158 }
159
160 Ok(result)
161 }
162
163 pub fn filter_columns_u32(
165 matrix: &CsrMatrix<u32>,
166 keep_indices: &[usize],
167 ) -> Result<CsrMatrix<u32>> {
168 let nrows = matrix.nrows();
169 let new_ncols = keep_indices.len();
170
171 if new_ncols == 0 {
172 return Ok(CsrMatrix::zeros(nrows, 0));
173 }
174
175 let col_map: FxHashMap<usize, usize> = keep_indices
177 .iter()
178 .enumerate()
179 .map(|(new_idx, &old_idx)| (old_idx, new_idx))
180 .collect();
181
182 let mut new_row_offsets = Vec::with_capacity(nrows + 1);
184 let mut new_col_indices = Vec::new();
185 let mut new_values = Vec::new();
186
187 new_row_offsets.push(0);
188
189 for row_idx in 0..nrows {
190 let row = matrix.row(row_idx);
191
192 for (&old_col, &val) in row.col_indices().iter().zip(row.values()) {
193 if let Some(&new_col) = col_map.get(&old_col) {
194 new_col_indices.push(new_col);
195 new_values.push(val);
196 }
197 }
198
199 new_row_offsets.push(new_col_indices.len());
200 }
201
202 CsrMatrix::try_from_csr_data(
204 nrows,
205 new_ncols,
206 new_row_offsets,
207 new_col_indices,
208 new_values,
209 )
210 .map_err(|e| {
211 RedicatError::SparseMatrix(format!("Failed to create filtered matrix: {:?}", e))
212 })
213 }
214
215 pub fn compute_row_sums(matrix: &CsrMatrix<u32>) -> Vec<u32> {
217 (0..matrix.nrows())
218 .into_par_iter()
219 .map(|row_idx| {
220 let row = matrix.row(row_idx);
221 row.values()
222 .iter()
223 .fold(0u64, |acc, &val| acc.saturating_add(val as u64))
224 .min(u32::MAX as u64) as u32
225 })
226 .collect()
227 }
228
229 pub fn compute_masked_row_sums(matrix: &CsrMatrix<u32>, mask: &[bool]) -> Vec<u32> {
231 let mask_len = mask.len();
232 if matrix.ncols() != mask_len {
233 return vec![0; matrix.nrows()];
234 }
235
236 (0..matrix.nrows())
237 .into_par_iter()
238 .map(|row_idx| {
239 let row = matrix.row(row_idx);
240 row.col_indices()
241 .iter()
242 .zip(row.values())
243 .fold(0u64, |acc, (&col_idx, &val)| {
244 if mask[col_idx] {
245 acc.saturating_add(val as u64)
246 } else {
247 acc
248 }
249 })
250 .min(u32::MAX as u64) as u32
251 })
252 .collect()
253 }
254
255 pub fn compute_col_sums(matrix: &CsrMatrix<u32>) -> Vec<u32> {
257 let ncols = matrix.ncols();
258
259 let chunk_size = std::cmp::max(1, matrix.nrows() / rayon::current_num_threads());
261
262 (0..matrix.nrows())
263 .into_par_iter()
264 .chunks(chunk_size)
265 .map(|chunk| {
266 let mut local_sums = vec![0u64; ncols];
267 for row_idx in chunk {
268 let row = matrix.row(row_idx);
269 for (&col_idx, &val) in row.col_indices().iter().zip(row.values()) {
270 local_sums[col_idx] = local_sums[col_idx].saturating_add(val as u64);
271 }
272 }
273 local_sums
274 })
275 .reduce(
276 || vec![0u64; ncols],
277 |mut acc, local| {
278 for (i, val) in local.into_iter().enumerate() {
279 acc[i] = acc[i].saturating_add(val);
280 }
281 acc
282 },
283 )
284 .into_iter()
285 .map(|sum| (sum.min(u32::MAX as u64)) as u32)
286 .collect()
287 }
288
289 pub fn element_wise_multiply(a: &CsrMatrix<u32>, b: &CsrMatrix<u8>) -> Result<CsrMatrix<u32>> {
296 if a.nrows() != b.nrows() || a.ncols() != b.ncols() {
297 return Err(RedicatError::DimensionMismatch {
298 expected: format!("{}×{}", a.nrows(), a.ncols()),
299 actual: format!("{}×{}", b.nrows(), b.ncols()),
300 });
301 }
302
303 let triplets: Vec<(usize, usize, u32)> = (0..a.nrows())
306 .into_par_iter()
307 .flat_map(|row_idx| {
308 let a_row = a.row(row_idx);
309 let b_row = b.row(row_idx);
310
311 let a_cols = a_row.col_indices();
312 let a_vals = a_row.values();
313 let b_cols = b_row.col_indices();
314 let b_vals = b_row.values();
315
316 let mut result: SmallVec<[(usize, usize, u32); 32]> = SmallVec::new();
319 let mut a_idx = 0;
320 let mut b_idx = 0;
321
322 while a_idx < a_cols.len() && b_idx < b_cols.len() {
323 let a_col = a_cols[a_idx];
324 let b_col = b_cols[b_idx];
325
326 match a_col.cmp(&b_col) {
327 std::cmp::Ordering::Equal => {
328 if b_vals[b_idx] > 0 {
330 result.push((row_idx, a_col, a_vals[a_idx]));
331 }
332 a_idx += 1;
333 b_idx += 1;
334 }
335 std::cmp::Ordering::Less => {
336 a_idx += 1;
338 }
339 std::cmp::Ordering::Greater => {
340 b_idx += 1;
342 }
343 }
344 }
345
346 result.into_vec()
348 })
349 .collect();
350
351 Self::from_triplets_u32(a.nrows(), a.ncols(), triplets)
352 }
353
354 pub fn transpose_u32(matrix: &CsrMatrix<u32>) -> CsrMatrix<u32> {
356 matrix.transpose()
357 }
358
359 pub fn matrix_vector_multiply(matrix: &CsrMatrix<u32>, vector: &[u32]) -> Result<Vec<u32>> {
361 if matrix.ncols() != vector.len() {
362 return Err(RedicatError::DimensionMismatch {
363 expected: format!("vector length = {}", matrix.ncols()),
364 actual: format!("vector length = {}", vector.len()),
365 });
366 }
367
368 let mut result = vec![0u64; matrix.nrows()];
369
370 result
372 .par_iter_mut()
373 .enumerate()
374 .for_each(|(row_idx, result_val)| {
375 let row = matrix.row(row_idx);
376 *result_val = row.col_indices().iter().zip(row.values()).fold(
377 0u64,
378 |acc, (&col_idx, &mat_val)| {
379 acc.saturating_add((mat_val as u64) * (vector[col_idx] as u64))
380 },
381 );
382 });
383
384 Ok(result
385 .into_iter()
386 .map(|val| (val.min(u32::MAX as u64)) as u32)
387 .collect())
388 }
389
390 pub fn get_density_stats(matrix: &CsrMatrix<u32>) -> (f64, usize, usize) {
392 let total_elements = matrix.nrows() * matrix.ncols();
393 let nnz = matrix.nnz();
394 let density = if total_elements > 0 {
395 nnz as f64 / total_elements as f64
396 } else {
397 0.0
398 };
399 (density, nnz, total_elements)
400 }
401
402 pub fn spill_to_file(matrix: &CsrMatrix<u32>, path: &Path) -> Result<()> {
411 let mut file = std::fs::File::create(path).map_err(RedicatError::Io)?;
412 let (row_offsets, col_indices, values) = matrix.csr_data();
413 let nrows = matrix.nrows() as u64;
414 let ncols = matrix.ncols() as u64;
415 let nnz = matrix.nnz() as u64;
416
417 file.write_all(&nrows.to_le_bytes()).map_err(RedicatError::Io)?;
418 file.write_all(&ncols.to_le_bytes()).map_err(RedicatError::Io)?;
419 file.write_all(&nnz.to_le_bytes()).map_err(RedicatError::Io)?;
420
421 for &offset in row_offsets {
422 file.write_all(&(offset as u64).to_le_bytes()).map_err(RedicatError::Io)?;
423 }
424 for &col in col_indices {
425 file.write_all(&(col as u64).to_le_bytes()).map_err(RedicatError::Io)?;
426 }
427 let value_bytes: &[u8] = unsafe {
429 std::slice::from_raw_parts(
430 values.as_ptr() as *const u8,
431 values.len() * std::mem::size_of::<u32>(),
432 )
433 };
434 file.write_all(value_bytes).map_err(RedicatError::Io)?;
435 file.flush().map_err(RedicatError::Io)?;
436 Ok(())
437 }
438
439 pub fn load_from_file(path: &Path) -> Result<CsrMatrix<u32>> {
441 let mut file = std::fs::File::open(path).map_err(RedicatError::Io)?;
442
443 let mut buf8 = [0u8; 8];
444 let read_u64 = |f: &mut std::fs::File, b: &mut [u8; 8]| -> Result<u64> {
445 f.read_exact(b).map_err(RedicatError::Io)?;
446 Ok(u64::from_le_bytes(*b))
447 };
448
449 let nrows = read_u64(&mut file, &mut buf8)? as usize;
450 let ncols = read_u64(&mut file, &mut buf8)? as usize;
451 let nnz = read_u64(&mut file, &mut buf8)? as usize;
452
453 let mut row_offsets = Vec::with_capacity(nrows + 1);
454 for _ in 0..=nrows {
455 row_offsets.push(read_u64(&mut file, &mut buf8)? as usize);
456 }
457
458 let mut col_indices = Vec::with_capacity(nnz);
459 for _ in 0..nnz {
460 col_indices.push(read_u64(&mut file, &mut buf8)? as usize);
461 }
462
463 let mut value_bytes = vec![0u8; nnz * std::mem::size_of::<u32>()];
464 file.read_exact(&mut value_bytes).map_err(RedicatError::Io)?;
465 let values: Vec<u32> = value_bytes
466 .chunks_exact(4)
467 .map(|chunk| u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
468 .collect();
469
470 CsrMatrix::try_from_csr_data(nrows, ncols, row_offsets, col_indices, values).map_err(
471 |e| RedicatError::SparseMatrix(format!("Failed to load spilled matrix: {:?}", e)),
472 )
473 }
474
475 pub fn estimate_csr_bytes(matrix: &CsrMatrix<u32>) -> usize {
477 let (row_offsets, col_indices, values) = matrix.csr_data();
478 row_offsets.len() * std::mem::size_of::<usize>()
479 + col_indices.len() * std::mem::size_of::<usize>()
480 + values.len() * std::mem::size_of::<u32>()
481 }
482}
483
484pub trait SparseMatrixExt<T> {
486 fn apply_threshold(&self, threshold: T) -> CsrMatrix<T>
487 where
488 T: Copy + PartialOrd + Default + nalgebra::Scalar;
489}
490
491impl SparseMatrixExt<u32> for CsrMatrix<u32> {
492 fn apply_threshold(&self, threshold: u32) -> CsrMatrix<u32> {
494 let triplets: Vec<(usize, usize, u32)> = self
495 .triplet_iter()
496 .filter_map(|(row, col, &val)| {
497 if val >= threshold {
498 Some((row, col, val))
499 } else {
500 None
501 }
502 })
503 .collect();
504
505 SparseOps::from_triplets_u32(self.nrows(), self.ncols(), triplets)
506 .unwrap_or_else(|_| CsrMatrix::zeros(self.nrows(), self.ncols()))
507 }
508}
509
#[cfg(test)]
mod tests {
    use super::*;

    /// Looks up the stored value at `(row, col)`, treating an absent entry as 0.
    fn matrix_value(matrix: &CsrMatrix<u32>, row: usize, col: usize) -> u32 {
        let lane = matrix.row(row);
        match lane.col_indices().iter().position(|&c| c == col) {
            Some(slot) => lane.values()[slot],
            None => 0,
        }
    }

    /// Builds a `u32` matrix fixture, panicking if the triplets are invalid.
    fn csr(nrows: usize, ncols: usize, entries: &[(usize, usize, u32)]) -> CsrMatrix<u32> {
        SparseOps::from_triplets_u32(nrows, ncols, entries.to_vec()).unwrap()
    }

    /// Checks every `(row, col, expected)` triple against `matrix`.
    fn assert_cells(matrix: &CsrMatrix<u32>, expected: &[(usize, usize, u32)]) {
        for &(row, col, want) in expected {
            assert_eq!(
                matrix_value(matrix, row, col),
                want,
                "unexpected value at ({}, {})",
                row,
                col
            );
        }
    }

    /// Spills `matrix` into a fresh temporary directory and loads it back.
    fn roundtrip(matrix: &CsrMatrix<u32>, file_name: &str) -> CsrMatrix<u32> {
        let scratch = tempfile::tempdir().unwrap();
        let target = scratch.path().join(file_name);
        SparseOps::spill_to_file(matrix, &target).unwrap();
        SparseOps::load_from_file(&target).unwrap()
    }

    #[test]
    fn test_parallel_sum_two_matrices() {
        let left = csr(
            3,
            3,
            &[(0, 0, 1), (0, 1, 2), (1, 1, 3), (1, 2, 4), (2, 0, 5), (2, 2, 6)],
        );
        let right = csr(
            3,
            3,
            &[(0, 0, 10), (0, 2, 20), (1, 1, 30), (2, 1, 40), (2, 2, 50)],
        );

        let sum = SparseOps::parallel_sum_matrices(&[&left, &right]).unwrap();

        assert_cells(
            &sum,
            &[
                (0, 0, 11),
                (0, 1, 2),
                (0, 2, 20),
                (1, 1, 33),
                (1, 2, 4),
                (2, 0, 5),
                (2, 1, 40),
                (2, 2, 56),
            ],
        );
    }

    #[test]
    fn test_parallel_sum_multiple_matrices() {
        let parts = [
            csr(2, 2, &[(0, 0, 1), (1, 1, 2)]),
            csr(2, 2, &[(0, 0, 3), (0, 1, 4)]),
            csr(2, 2, &[(1, 0, 5), (1, 1, 6)]),
            csr(2, 2, &[(0, 1, 7), (1, 0, 8)]),
        ];
        let refs: Vec<&CsrMatrix<u32>> = parts.iter().collect();

        let sum = SparseOps::parallel_sum_matrices(&refs).unwrap();

        assert_cells(&sum, &[(0, 0, 4), (0, 1, 11), (1, 0, 13), (1, 1, 8)]);
    }

    #[test]
    fn test_parallel_sum_eight_matrices() {
        // Diagonal matrices weighted 1..=8; the weights add up to 36.
        let parts: Vec<CsrMatrix<u32>> = (1..=8u32)
            .map(|weight| csr(2, 2, &[(0, 0, weight), (1, 1, weight)]))
            .collect();
        let refs: Vec<&CsrMatrix<u32>> = parts.iter().collect();

        let sum = SparseOps::parallel_sum_matrices(&refs).unwrap();

        assert_cells(&sum, &[(0, 0, 36), (1, 1, 36), (0, 1, 0), (1, 0, 0)]);
    }

    #[test]
    fn test_from_triplets_empty_input() {
        let empty = SparseOps::from_triplets_u32(5, 5, vec![]).unwrap();
        assert_eq!((empty.nrows(), empty.ncols()), (5, 5));
        assert_eq!(empty.nnz(), 0);
    }

    #[test]
    fn test_from_triplets_zero_dimensions() {
        let empty = SparseOps::from_triplets_u32(0, 0, vec![]).unwrap();
        assert_eq!((empty.nrows(), empty.ncols()), (0, 0));
    }

    #[test]
    fn test_from_triplets_out_of_bounds() {
        let outcome = SparseOps::from_triplets_u32(2, 2, vec![(3, 0, 1)]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_from_triplets_duplicate_entries_summed() {
        let merged = csr(2, 2, &[(0, 0, 3), (0, 0, 7)]);
        assert_eq!(matrix_value(&merged, 0, 0), 10);
    }

    #[test]
    fn test_add_matrices_dimension_mismatch() {
        let wide = csr(2, 3, &[(0, 0, 1)]);
        let tall = csr(3, 2, &[(0, 0, 1)]);
        assert!(SparseOps::add_matrices(&wide, &tall).is_err());
    }

    #[test]
    fn test_add_matrices_one_empty() {
        let filled = csr(3, 3, &[(0, 0, 5), (2, 2, 10)]);
        let blank = CsrMatrix::<u32>::zeros(3, 3);

        let sum = SparseOps::add_matrices(&filled, &blank).unwrap();

        assert_cells(&sum, &[(0, 0, 5), (2, 2, 10)]);
        assert_eq!(sum.nnz(), 2);
    }

    #[test]
    fn test_filter_columns_keeps_correct_subset() {
        let dense = csr(
            2,
            4,
            &[
                (0, 0, 1),
                (0, 1, 2),
                (0, 2, 3),
                (0, 3, 4),
                (1, 0, 5),
                (1, 1, 6),
                (1, 2, 7),
                (1, 3, 8),
            ],
        );

        let subset = SparseOps::filter_columns_u32(&dense, &[1, 3]).unwrap();

        assert_eq!(subset.nrows(), 2);
        assert_eq!(subset.ncols(), 2);
        assert_cells(&subset, &[(0, 0, 2), (0, 1, 4), (1, 0, 6), (1, 1, 8)]);
    }

    #[test]
    fn test_filter_columns_empty_keep() {
        let source = csr(2, 3, &[(0, 0, 1)]);
        let subset = SparseOps::filter_columns_u32(&source, &[]).unwrap();
        assert_eq!(subset.ncols(), 0);
        assert_eq!(subset.nnz(), 0);
    }

    #[test]
    fn test_filter_columns_preserves_sparsity() {
        let source = csr(100, 100, &[(0, 10, 1), (50, 50, 2), (99, 99, 3)]);
        let first_half: Vec<usize> = (0..50).collect();

        let subset = SparseOps::filter_columns_u32(&source, &first_half).unwrap();

        assert_eq!(subset.ncols(), 50);
        assert_eq!(subset.nnz(), 1);
        assert_eq!(matrix_value(&subset, 0, 10), 1);
    }

    #[test]
    fn test_compute_row_sums_basic() {
        let source = csr(
            3,
            3,
            &[(0, 0, 1), (0, 1, 2), (0, 2, 3), (1, 1, 10), (2, 0, 5), (2, 2, 5)],
        );
        let totals = SparseOps::compute_row_sums(&source);
        assert_eq!(totals, vec![6, 10, 10]);
    }

    #[test]
    fn test_compute_row_sums_empty_matrix() {
        let blank = CsrMatrix::<u32>::zeros(3, 4);
        let totals = SparseOps::compute_row_sums(&blank);
        assert_eq!(totals, vec![0, 0, 0]);
    }

    #[test]
    fn test_compute_col_sums_basic() {
        let source = csr(
            3,
            3,
            &[(0, 0, 1), (1, 0, 2), (2, 0, 3), (0, 2, 10), (1, 2, 20)],
        );
        let totals = SparseOps::compute_col_sums(&source);
        assert_eq!(totals, vec![6, 0, 30]);
    }

    #[test]
    fn test_compute_col_sums_empty() {
        let blank = CsrMatrix::<u32>::zeros(2, 5);
        let totals = SparseOps::compute_col_sums(&blank);
        assert_eq!(totals, vec![0, 0, 0, 0, 0]);
    }

    #[test]
    fn test_compute_masked_row_sums_basic() {
        let source = csr(
            2,
            4,
            &[
                (0, 0, 10),
                (0, 1, 20),
                (0, 2, 30),
                (0, 3, 40),
                (1, 0, 1),
                (1, 1, 2),
                (1, 2, 3),
                (1, 3, 4),
            ],
        );
        let even_columns = [true, false, true, false];

        let totals = SparseOps::compute_masked_row_sums(&source, &even_columns);

        assert_eq!(totals, vec![40, 4]);
    }

    #[test]
    fn test_compute_masked_row_sums_all_false() {
        let source = csr(2, 3, &[(0, 0, 99)]);
        let nothing = [false, false, false];
        let totals = SparseOps::compute_masked_row_sums(&source, &nothing);
        assert_eq!(totals, vec![0, 0]);
    }

    #[test]
    fn test_compute_masked_row_sums_wrong_length() {
        let source = csr(2, 3, &[(0, 0, 1)]);
        let too_short = [true, false];
        let totals = SparseOps::compute_masked_row_sums(&source, &too_short);
        assert_eq!(totals, vec![0, 0]);
    }

    #[test]
    fn test_element_wise_multiply_basic() {
        let counts = csr(2, 2, &[(0, 0, 10), (0, 1, 20), (1, 0, 30), (1, 1, 40)]);
        let mask = SparseOps::from_triplets(2, 2, vec![(0, 0, 1), (0, 1, 0), (1, 1, 1)]).unwrap();

        let masked = SparseOps::element_wise_multiply(&counts, &mask).unwrap();

        assert_cells(&masked, &[(0, 0, 10), (0, 1, 0), (1, 0, 0), (1, 1, 40)]);
    }

    #[test]
    fn test_element_wise_multiply_dimension_mismatch() {
        let counts = SparseOps::from_triplets_u32(2, 3, vec![]).unwrap();
        let mask = SparseOps::from_triplets(3, 2, vec![]).unwrap();
        assert!(SparseOps::element_wise_multiply(&counts, &mask).is_err());
    }

    #[test]
    fn test_transpose_basic() {
        let source = csr(2, 3, &[(0, 0, 1), (0, 2, 2), (1, 1, 3)]);

        let flipped = SparseOps::transpose_u32(&source);

        assert_eq!(flipped.nrows(), 3);
        assert_eq!(flipped.ncols(), 2);
        assert_cells(&flipped, &[(0, 0, 1), (2, 0, 2), (1, 1, 3)]);
    }

    #[test]
    fn test_matrix_vector_multiply() {
        let source = csr(
            2,
            3,
            &[(0, 0, 1), (0, 1, 2), (0, 2, 3), (1, 0, 4), (1, 1, 5), (1, 2, 6)],
        );
        let powers_of_ten = vec![1, 10, 100];

        let product = SparseOps::matrix_vector_multiply(&source, &powers_of_ten).unwrap();

        assert_eq!(product, vec![321, 654]);
    }

    #[test]
    fn test_matrix_vector_multiply_dimension_mismatch() {
        let source = SparseOps::from_triplets_u32(2, 3, vec![]).unwrap();
        let outcome = SparseOps::matrix_vector_multiply(&source, &[1, 2]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_density_stats() {
        let source = csr(10, 10, &[(0, 0, 1), (5, 5, 2), (9, 9, 3)]);

        let (density, nnz, total) = SparseOps::get_density_stats(&source);

        assert_eq!(nnz, 3);
        assert_eq!(total, 100);
        assert!((density - 0.03).abs() < 1e-10);
    }

    #[test]
    fn test_apply_threshold() {
        let source = csr(3, 3, &[(0, 0, 1), (0, 1, 5), (1, 1, 10), (2, 2, 3)]);

        let kept = source.apply_threshold(5);

        assert_cells(&kept, &[(0, 0, 0), (0, 1, 5), (1, 1, 10), (2, 2, 0)]);
    }

    #[test]
    fn test_parallel_sum_single_matrix() {
        let only = csr(2, 2, &[(0, 0, 42), (1, 1, 24)]);

        let sum = SparseOps::parallel_sum_matrices(&[&only]).unwrap();

        assert_cells(&sum, &[(0, 0, 42), (1, 1, 24)]);
    }

    #[test]
    fn test_parallel_sum_preserves_sparsity() {
        let left = csr(100, 100, &[(0, 0, 1), (10, 10, 2), (50, 50, 3)]);
        let right = csr(100, 100, &[(0, 0, 10), (20, 20, 20), (50, 50, 30)]);

        let sum = SparseOps::parallel_sum_matrices(&[&left, &right]).unwrap();

        assert!(sum.nnz() <= 6);
        assert_cells(
            &sum,
            &[(0, 0, 11), (10, 10, 2), (20, 20, 20), (50, 50, 33)],
        );
    }

    #[test]
    fn test_parallel_sum_dimension_mismatch() {
        let small = csr(2, 2, &[(0, 0, 1)]);
        let large = csr(3, 3, &[(0, 0, 1)]);

        let outcome = SparseOps::parallel_sum_matrices(&[&small, &large]);

        assert!(outcome.is_err());
    }

    #[test]
    fn test_parallel_sum_empty_list() {
        let outcome = SparseOps::parallel_sum_matrices(&[]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_parallel_sum_large_scale() {
        let n_matrices = 8;
        let size = 1000;
        let density = 0.01;
        let n_nonzeros = (size as f64 * size as f64 * density) as usize;

        // Deterministic pseudo-scatter of unit entries, shifted per matrix.
        let mut parts: Vec<CsrMatrix<u32>> = Vec::with_capacity(n_matrices);
        for matrix_idx in 0..n_matrices {
            let mut triplets: Vec<(usize, usize, u32)> = Vec::with_capacity(n_nonzeros);
            for i in 0..n_nonzeros {
                let row = (i * 7 + matrix_idx * 13) % size;
                let col = (i * 11 + matrix_idx * 17) % size;
                triplets.push((row, col, 1));
            }
            parts.push(SparseOps::from_triplets_u32(size, size, triplets).unwrap());
        }

        let refs: Vec<&CsrMatrix<u32>> = parts.iter().collect();
        let sum = SparseOps::parallel_sum_matrices(&refs).unwrap();

        assert_eq!(sum.nrows(), size);
        assert_eq!(sum.ncols(), size);

        let density_result = sum.nnz() as f64 / (size * size) as f64;
        assert!(density_result < 0.1, "Result should maintain sparsity");
    }

    #[test]
    fn test_spill_and_load_roundtrip() {
        let entries = [(0, 0, 1), (0, 3, 42), (1, 1, 100), (2, 2, 7), (2, 3, 99)];
        let original = csr(3, 4, &entries);

        let restored = roundtrip(&original, "matrix.bin");

        assert_eq!(restored.nrows(), 3);
        assert_eq!(restored.ncols(), 4);
        assert_eq!(restored.nnz(), 5);
        assert_cells(&restored, &entries);
    }

    #[test]
    fn test_spill_and_load_empty_matrix() {
        let blank = CsrMatrix::<u32>::zeros(5, 10);

        let restored = roundtrip(&blank, "empty.bin");

        assert_eq!(restored.nrows(), 5);
        assert_eq!(restored.ncols(), 10);
        assert_eq!(restored.nnz(), 0);
    }

    #[test]
    fn test_spill_and_load_large_values() {
        let extremes = csr(1, 2, &[(0, 0, u32::MAX), (0, 1, u32::MAX - 1)]);

        let restored = roundtrip(&extremes, "large.bin");

        assert_cells(&restored, &[(0, 0, u32::MAX), (0, 1, u32::MAX - 1)]);
    }

    #[test]
    fn test_estimate_csr_bytes_nonzero() {
        let source = csr(10, 10, &[(0, 0, 1), (5, 5, 2), (9, 9, 3)]);
        let bytes = SparseOps::estimate_csr_bytes(&source);
        assert!(bytes >= 100, "Expected >= 100 bytes, got {}", bytes);
    }
}