// redicat_lib/core/sparse.rs
1//! High-performance sparse matrix utilities shared across REDICAT
2
3use crate::core::error::{RedicatError, Result};
4use itertools::Itertools;
5use nalgebra_sparse::ops::serial::spadd_csr_prealloc;
6use nalgebra_sparse::ops::Op;
7use nalgebra_sparse::{CooMatrix, CsrMatrix};
8use rayon::prelude::*;
9use rustc_hash::FxHashMap;
10use smallvec::SmallVec;
11use std::io::{Read, Write};
12use std::path::Path;
13
/// Zero-sized namespace type for REDICAT's sparse-matrix helper routines;
/// all functionality lives in associated functions (no instance state).
pub struct SparseOps;
15
16impl SparseOps {
17    /// Create CSR matrix from COO format using nalgebra_sparse native conversion
18    pub fn from_triplets_u32(
19        nrows: usize,
20        ncols: usize,
21        triplets: Vec<(usize, usize, u32)>,
22    ) -> Result<CsrMatrix<u32>> {
23        if nrows == 0 || ncols == 0 {
24            return Ok(CsrMatrix::zeros(nrows, ncols));
25        }
26
27        if triplets.is_empty() {
28            return Ok(CsrMatrix::zeros(nrows, ncols));
29        }
30
31        // Validate indices
32        for &(row, col, _) in &triplets {
33            if row >= nrows || col >= ncols {
34                return Err(RedicatError::InvalidInput(format!(
35                    "Index ({}, {}) exceeds matrix dimensions ({}, {})",
36                    row, col, nrows, ncols
37                )));
38            }
39        }
40
41        // Use COO format first, then convert to CSR using native nalgebra_sparse
42        let (row_indices, col_indices, values): (Vec<_>, Vec<_>, Vec<_>) =
43            triplets.into_iter().multiunzip();
44
45        let coo = CooMatrix::try_from_triplets(nrows, ncols, row_indices, col_indices, values)
46            .map_err(|e| RedicatError::SparseMatrix(format!("COO creation failed: {:?}", e)))?;
47
48        // Use native conversion from COO to CSR
49        let csr = CsrMatrix::from(&coo);
50        Ok(csr)
51    }
52
53    /// Create CSR matrix from triplets with u8 values
54    pub fn from_triplets(
55        nrows: usize,
56        ncols: usize,
57        triplets: Vec<(usize, usize, u8)>,
58    ) -> Result<CsrMatrix<u8>> {
59        if nrows == 0 || ncols == 0 {
60            return Ok(CsrMatrix::zeros(nrows, ncols));
61        }
62
63        if triplets.is_empty() {
64            return Ok(CsrMatrix::zeros(nrows, ncols));
65        }
66
67        let (row_indices, col_indices, values): (Vec<_>, Vec<_>, Vec<_>) =
68            triplets.into_iter().multiunzip();
69
70        let coo = CooMatrix::try_from_triplets(nrows, ncols, row_indices, col_indices, values)
71            .map_err(|e| RedicatError::SparseMatrix(format!("COO creation failed: {:?}", e)))?;
72
73        Ok(CsrMatrix::from(&coo))
74    }
75
76    /// Highly optimized matrix addition using nalgebra_sparse's spadd operation
77    pub fn add_matrices(a: &CsrMatrix<u32>, b: &CsrMatrix<u32>) -> Result<CsrMatrix<u32>> {
78        if a.nrows() != b.nrows() || a.ncols() != b.ncols() {
79            return Err(RedicatError::DimensionMismatch {
80                expected: format!("{}×{}", a.nrows(), a.ncols()),
81                actual: format!("{}×{}", b.nrows(), b.ncols()),
82            });
83        }
84
85        // Use nalgebra_sparse's native sparse addition pattern computation
86        let pattern = nalgebra_sparse::ops::serial::spadd_pattern(a.pattern(), b.pattern());
87
88        // Pre-allocate result matrix with computed pattern
89        let mut result =
90            CsrMatrix::try_from_pattern_and_values(pattern.clone(), vec![0u32; pattern.nnz()])
91                .map_err(|e| {
92                    RedicatError::SparseMatrix(format!("Failed to create result matrix: {:?}", e))
93                })?;
94
95        // Use native sparse addition operation with correct API
96        // API signature: spadd_csr_prealloc(beta, C, alpha, Op<A>)
97        spadd_csr_prealloc(1u32, &mut result, 1u32, Op::NoOp(a))
98            .map_err(|e| RedicatError::SparseMatrix(format!("Sparse addition failed: {:?}", e)))?;
99
100        spadd_csr_prealloc(1u32, &mut result, 1u32, Op::NoOp(b))
101            .map_err(|e| RedicatError::SparseMatrix(format!("Sparse addition failed: {:?}", e)))?;
102
103        Ok(result)
104    }
105
106    /// Sum multiple sparse matrices into a single result using a single-allocation
107    /// fold strategy optimized for ≤8 matrices.
108    ///
109    /// Instead of the previous tree-reduction that cloned every matrix at each
110    /// level, this implementation:
111    /// 1. Computes the *union* sparsity pattern across all input matrices.
112    /// 2. Allocates **one** result matrix with that pattern (values zeroed).
113    /// 3. Folds each input matrix into the result via `spadd_csr_prealloc`
114    ///    which writes directly into the pre-allocated buffer — zero additional
115    ///    matrix allocations.
116    ///
117    /// For the typical call-pipeline case (≤8 base-count layers) this reduces
118    /// peak transient memory from O(n) matrix copies to exactly 1 matrix.
119    pub fn parallel_sum_matrices(matrices: &[&CsrMatrix<u32>]) -> Result<CsrMatrix<u32>> {
120        if matrices.is_empty() {
121            return Err(RedicatError::EmptyData("No matrices to sum".to_string()));
122        }
123
124        if matrices.len() == 1 {
125            return Ok(matrices[0].clone());
126        }
127
128        // Verify all matrices have the same dimensions
129        let (nrows, ncols) = (matrices[0].nrows(), matrices[0].ncols());
130        for matrix in matrices.iter().skip(1) {
131            if matrix.nrows() != nrows || matrix.ncols() != ncols {
132                return Err(RedicatError::DimensionMismatch {
133                    expected: format!("{}×{}", nrows, ncols),
134                    actual: format!("{}×{}", matrix.nrows(), matrix.ncols()),
135                });
136            }
137        }
138
139        // Compute union sparsity pattern by folding spadd_pattern across all inputs.
140        let mut union_pattern = matrices[0].pattern().clone();
141        for matrix in matrices.iter().skip(1) {
142            union_pattern =
143                nalgebra_sparse::ops::serial::spadd_pattern(&union_pattern, matrix.pattern());
144        }
145
146        // Allocate the single result matrix with the union pattern, all zeros.
147        let nnz = union_pattern.nnz();
148        let mut result =
149            CsrMatrix::try_from_pattern_and_values(union_pattern, vec![0u32; nnz]).map_err(
150                |e| RedicatError::SparseMatrix(format!("Failed to create result matrix: {:?}", e)),
151            )?;
152
153        // Fold each input into the result in-place.
154        for matrix in matrices {
155            spadd_csr_prealloc(1u32, &mut result, 1u32, Op::NoOp(*matrix)).map_err(|e| {
156                RedicatError::SparseMatrix(format!("Sparse addition failed: {:?}", e))
157            })?;
158        }
159
160        Ok(result)
161    }
162
163    /// Optimized column filtering using sparsity pattern operations
164    pub fn filter_columns_u32(
165        matrix: &CsrMatrix<u32>,
166        keep_indices: &[usize],
167    ) -> Result<CsrMatrix<u32>> {
168        let nrows = matrix.nrows();
169        let new_ncols = keep_indices.len();
170
171        if new_ncols == 0 {
172            return Ok(CsrMatrix::zeros(nrows, 0));
173        }
174
175        // Create efficient column mapping using FxHashMap for better performance
176        let col_map: FxHashMap<usize, usize> = keep_indices
177            .iter()
178            .enumerate()
179            .map(|(new_idx, &old_idx)| (old_idx, new_idx))
180            .collect();
181
182        // Build new sparsity pattern efficiently
183        let mut new_row_offsets = Vec::with_capacity(nrows + 1);
184        let mut new_col_indices = Vec::new();
185        let mut new_values = Vec::new();
186
187        new_row_offsets.push(0);
188
189        for row_idx in 0..nrows {
190            let row = matrix.row(row_idx);
191
192            for (&old_col, &val) in row.col_indices().iter().zip(row.values()) {
193                if let Some(&new_col) = col_map.get(&old_col) {
194                    new_col_indices.push(new_col);
195                    new_values.push(val);
196                }
197            }
198
199            new_row_offsets.push(new_col_indices.len());
200        }
201
202        // Create CSR matrix directly using nalgebra_sparse
203        CsrMatrix::try_from_csr_data(
204            nrows,
205            new_ncols,
206            new_row_offsets,
207            new_col_indices,
208            new_values,
209        )
210        .map_err(|e| {
211            RedicatError::SparseMatrix(format!("Failed to create filtered matrix: {:?}", e))
212        })
213    }
214
215    /// Optimized row sums using native CSR structure access
216    pub fn compute_row_sums(matrix: &CsrMatrix<u32>) -> Vec<u32> {
217        (0..matrix.nrows())
218            .into_par_iter()
219            .map(|row_idx| {
220                let row = matrix.row(row_idx);
221                row.values()
222                    .iter()
223                    .fold(0u64, |acc, &val| acc.saturating_add(val as u64))
224                    .min(u32::MAX as u64) as u32
225            })
226            .collect()
227    }
228
229    /// Row sums restricted to the columns flagged by `mask`
230    pub fn compute_masked_row_sums(matrix: &CsrMatrix<u32>, mask: &[bool]) -> Vec<u32> {
231        let mask_len = mask.len();
232        if matrix.ncols() != mask_len {
233            return vec![0; matrix.nrows()];
234        }
235
236        (0..matrix.nrows())
237            .into_par_iter()
238            .map(|row_idx| {
239                let row = matrix.row(row_idx);
240                row.col_indices()
241                    .iter()
242                    .zip(row.values())
243                    .fold(0u64, |acc, (&col_idx, &val)| {
244                        if mask[col_idx] {
245                            acc.saturating_add(val as u64)
246                        } else {
247                            acc
248                        }
249                    })
250                    .min(u32::MAX as u64) as u32
251            })
252            .collect()
253    }
254
255    /// Optimized column sums using parallel reduction over CSR structure
256    pub fn compute_col_sums(matrix: &CsrMatrix<u32>) -> Vec<u32> {
257        let ncols = matrix.ncols();
258
259        // Use parallel reduction with chunked processing
260        let chunk_size = std::cmp::max(1, matrix.nrows() / rayon::current_num_threads());
261
262        (0..matrix.nrows())
263            .into_par_iter()
264            .chunks(chunk_size)
265            .map(|chunk| {
266                let mut local_sums = vec![0u64; ncols];
267                for row_idx in chunk {
268                    let row = matrix.row(row_idx);
269                    for (&col_idx, &val) in row.col_indices().iter().zip(row.values()) {
270                        local_sums[col_idx] = local_sums[col_idx].saturating_add(val as u64);
271                    }
272                }
273                local_sums
274            })
275            .reduce(
276                || vec![0u64; ncols],
277                |mut acc, local| {
278                    for (i, val) in local.into_iter().enumerate() {
279                        acc[i] = acc[i].saturating_add(val);
280                    }
281                    acc
282                },
283            )
284            .into_iter()
285            .map(|sum| (sum.min(u32::MAX as u64)) as u32)
286            .collect()
287    }
288
289    /// Element-wise multiplication using optimized two-pointer algorithm for sparse matrices
290    ///
291    /// This implementation uses a sorted two-pointer merge algorithm which is significantly
292    /// faster than HashMap-based intersection for sparse matrices. CSR format guarantees
293    /// column indices are already sorted within each row, allowing O(nnz_a + nnz_b) complexity
294    /// instead of O(nnz_a * log(nnz_b)) with HashMap lookups.
295    pub fn element_wise_multiply(a: &CsrMatrix<u32>, b: &CsrMatrix<u8>) -> Result<CsrMatrix<u32>> {
296        if a.nrows() != b.nrows() || a.ncols() != b.ncols() {
297            return Err(RedicatError::DimensionMismatch {
298                expected: format!("{}×{}", a.nrows(), a.ncols()),
299                actual: format!("{}×{}", b.nrows(), b.ncols()),
300            });
301        }
302
303        // Use parallel processing with two-pointer algorithm for element-wise multiplication
304        // SmallVec optimizes for typical sparse row sizes (most rows have < 32 non-zero elements)
305        let triplets: Vec<(usize, usize, u32)> = (0..a.nrows())
306            .into_par_iter()
307            .flat_map(|row_idx| {
308                let a_row = a.row(row_idx);
309                let b_row = b.row(row_idx);
310
311                let a_cols = a_row.col_indices();
312                let a_vals = a_row.values();
313                let b_cols = b_row.col_indices();
314                let b_vals = b_row.values();
315
316                // Two-pointer algorithm for sorted sparse row intersection
317                // CSR format guarantees column indices are sorted, so we can use merge-like algorithm
318                let mut result: SmallVec<[(usize, usize, u32); 32]> = SmallVec::new();
319                let mut a_idx = 0;
320                let mut b_idx = 0;
321
322                while a_idx < a_cols.len() && b_idx < b_cols.len() {
323                    let a_col = a_cols[a_idx];
324                    let b_col = b_cols[b_idx];
325
326                    match a_col.cmp(&b_col) {
327                        std::cmp::Ordering::Equal => {
328                            // Column indices match - this is an intersection point
329                            if b_vals[b_idx] > 0 {
330                                result.push((row_idx, a_col, a_vals[a_idx]));
331                            }
332                            a_idx += 1;
333                            b_idx += 1;
334                        }
335                        std::cmp::Ordering::Less => {
336                            // a has a column that b doesn't have yet
337                            a_idx += 1;
338                        }
339                        std::cmp::Ordering::Greater => {
340                            // b has a column that a doesn't have yet
341                            b_idx += 1;
342                        }
343                    }
344                }
345
346                // Convert SmallVec to Vec for flat_map compatibility
347                result.into_vec()
348            })
349            .collect();
350
351        Self::from_triplets_u32(a.nrows(), a.ncols(), triplets)
352    }
353
    /// Transpose operation using nalgebra_sparse native transpose
    ///
    /// Returns a new matrix with rows and columns swapped; `matrix` is left
    /// unmodified. Delegates entirely to `CsrMatrix::transpose`.
    pub fn transpose_u32(matrix: &CsrMatrix<u32>) -> CsrMatrix<u32> {
        matrix.transpose()
    }
358
359    /// Matrix-vector multiplication using native spmv operation
360    pub fn matrix_vector_multiply(matrix: &CsrMatrix<u32>, vector: &[u32]) -> Result<Vec<u32>> {
361        if matrix.ncols() != vector.len() {
362            return Err(RedicatError::DimensionMismatch {
363                expected: format!("vector length = {}", matrix.ncols()),
364                actual: format!("vector length = {}", vector.len()),
365            });
366        }
367
368        let mut result = vec![0u64; matrix.nrows()];
369
370        // Use parallel processing for matrix-vector multiplication
371        result
372            .par_iter_mut()
373            .enumerate()
374            .for_each(|(row_idx, result_val)| {
375                let row = matrix.row(row_idx);
376                *result_val = row.col_indices().iter().zip(row.values()).fold(
377                    0u64,
378                    |acc, (&col_idx, &mat_val)| {
379                        acc.saturating_add((mat_val as u64) * (vector[col_idx] as u64))
380                    },
381                );
382            });
383
384        Ok(result
385            .into_iter()
386            .map(|val| (val.min(u32::MAX as u64)) as u32)
387            .collect())
388    }
389
390    /// Get matrix density statistics
391    pub fn get_density_stats(matrix: &CsrMatrix<u32>) -> (f64, usize, usize) {
392        let total_elements = matrix.nrows() * matrix.ncols();
393        let nnz = matrix.nnz();
394        let density = if total_elements > 0 {
395            nnz as f64 / total_elements as f64
396        } else {
397            0.0
398        };
399        (density, nnz, total_elements)
400    }
401
402    /// Serialize a CsrMatrix<u32> to a file in a compact binary format.
403    ///
404    /// Layout: [nrows: u64][ncols: u64][nnz: u64]
405    ///         [row_offsets: (nrows+1) × u64]
406    ///         [col_indices: nnz × u64]
407    ///         [values: nnz × u32]
408    ///
409    /// All integers are written in **little-endian** byte order.
410    pub fn spill_to_file(matrix: &CsrMatrix<u32>, path: &Path) -> Result<()> {
411        let mut file = std::fs::File::create(path).map_err(RedicatError::Io)?;
412        let (row_offsets, col_indices, values) = matrix.csr_data();
413        let nrows = matrix.nrows() as u64;
414        let ncols = matrix.ncols() as u64;
415        let nnz = matrix.nnz() as u64;
416
417        file.write_all(&nrows.to_le_bytes()).map_err(RedicatError::Io)?;
418        file.write_all(&ncols.to_le_bytes()).map_err(RedicatError::Io)?;
419        file.write_all(&nnz.to_le_bytes()).map_err(RedicatError::Io)?;
420
421        for &offset in row_offsets {
422            file.write_all(&(offset as u64).to_le_bytes()).map_err(RedicatError::Io)?;
423        }
424        for &col in col_indices {
425            file.write_all(&(col as u64).to_le_bytes()).map_err(RedicatError::Io)?;
426        }
427        // Write values as a contiguous byte slice for efficiency
428        let value_bytes: &[u8] = unsafe {
429            std::slice::from_raw_parts(
430                values.as_ptr() as *const u8,
431                values.len() * std::mem::size_of::<u32>(),
432            )
433        };
434        file.write_all(value_bytes).map_err(RedicatError::Io)?;
435        file.flush().map_err(RedicatError::Io)?;
436        Ok(())
437    }
438
439    /// Deserialize a CsrMatrix<u32> from a file written by [`spill_to_file`].
440    pub fn load_from_file(path: &Path) -> Result<CsrMatrix<u32>> {
441        let mut file = std::fs::File::open(path).map_err(RedicatError::Io)?;
442
443        let mut buf8 = [0u8; 8];
444        let read_u64 = |f: &mut std::fs::File, b: &mut [u8; 8]| -> Result<u64> {
445            f.read_exact(b).map_err(RedicatError::Io)?;
446            Ok(u64::from_le_bytes(*b))
447        };
448
449        let nrows = read_u64(&mut file, &mut buf8)? as usize;
450        let ncols = read_u64(&mut file, &mut buf8)? as usize;
451        let nnz = read_u64(&mut file, &mut buf8)? as usize;
452
453        let mut row_offsets = Vec::with_capacity(nrows + 1);
454        for _ in 0..=nrows {
455            row_offsets.push(read_u64(&mut file, &mut buf8)? as usize);
456        }
457
458        let mut col_indices = Vec::with_capacity(nnz);
459        for _ in 0..nnz {
460            col_indices.push(read_u64(&mut file, &mut buf8)? as usize);
461        }
462
463        let mut value_bytes = vec![0u8; nnz * std::mem::size_of::<u32>()];
464        file.read_exact(&mut value_bytes).map_err(RedicatError::Io)?;
465        let values: Vec<u32> = value_bytes
466            .chunks_exact(4)
467            .map(|chunk| u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
468            .collect();
469
470        CsrMatrix::try_from_csr_data(nrows, ncols, row_offsets, col_indices, values).map_err(
471            |e| RedicatError::SparseMatrix(format!("Failed to load spilled matrix: {:?}", e)),
472        )
473    }
474
475    /// Estimate the in-memory byte footprint of a CsrMatrix<u32>.
476    pub fn estimate_csr_bytes(matrix: &CsrMatrix<u32>) -> usize {
477        let (row_offsets, col_indices, values) = matrix.csr_data();
478        row_offsets.len() * std::mem::size_of::<usize>()
479            + col_indices.len() * std::mem::size_of::<usize>()
480            + values.len() * std::mem::size_of::<u32>()
481    }
482}
483
/// Trait extension for additional sparse matrix operations
pub trait SparseMatrixExt<T> {
    /// Return a same-shaped matrix filtered by value; the `u32` implementation
    /// keeps only entries with value `>= threshold` and drops the rest from
    /// the sparsity pattern.
    fn apply_threshold(&self, threshold: T) -> CsrMatrix<T>
    where
        T: Copy + PartialOrd + Default + nalgebra::Scalar;
}
490
491impl SparseMatrixExt<u32> for CsrMatrix<u32> {
492    /// Apply threshold to sparse matrix values
493    fn apply_threshold(&self, threshold: u32) -> CsrMatrix<u32> {
494        let triplets: Vec<(usize, usize, u32)> = self
495            .triplet_iter()
496            .filter_map(|(row, col, &val)| {
497                if val >= threshold {
498                    Some((row, col, val))
499                } else {
500                    None
501                }
502            })
503            .collect();
504
505        SparseOps::from_triplets_u32(self.nrows(), self.ncols(), triplets)
506            .unwrap_or_else(|_| CsrMatrix::zeros(self.nrows(), self.ncols()))
507    }
508}
509
510#[cfg(test)]
511mod tests {
512    use super::*;
513
514    fn matrix_value(matrix: &CsrMatrix<u32>, row: usize, col: usize) -> u32 {
515        let row_view = matrix.row(row);
516        row_view
517            .col_indices()
518            .iter()
519            .zip(row_view.values())
520            .find_map(|(&col_idx, &value)| (col_idx == col).then_some(value))
521            .unwrap_or(0)
522    }
523
524    #[test]
525    fn test_parallel_sum_two_matrices() {
526        let m1 = SparseOps::from_triplets_u32(3, 3, vec![
527            (0, 0, 1), (0, 1, 2),
528            (1, 1, 3), (1, 2, 4),
529            (2, 0, 5), (2, 2, 6),
530        ]).unwrap();
531
532        let m2 = SparseOps::from_triplets_u32(3, 3, vec![
533            (0, 0, 10), (0, 2, 20),
534            (1, 1, 30),
535            (2, 1, 40), (2, 2, 50),
536        ]).unwrap();
537
538        let result = SparseOps::parallel_sum_matrices(&[&m1, &m2]).unwrap();
539
540        assert_eq!(matrix_value(&result, 0, 0), 11);  // 1 + 10
541        assert_eq!(matrix_value(&result, 0, 1), 2);   // 2 + 0
542        assert_eq!(matrix_value(&result, 0, 2), 20);  // 0 + 20
543        assert_eq!(matrix_value(&result, 1, 1), 33);  // 3 + 30
544        assert_eq!(matrix_value(&result, 1, 2), 4);   // 4 + 0
545        assert_eq!(matrix_value(&result, 2, 0), 5);   // 5 + 0
546        assert_eq!(matrix_value(&result, 2, 1), 40);  // 0 + 40
547        assert_eq!(matrix_value(&result, 2, 2), 56);  // 6 + 50
548    }
549
550    #[test]
551    fn test_parallel_sum_multiple_matrices() {
552        let m1 = SparseOps::from_triplets_u32(2, 2, vec![(0, 0, 1), (1, 1, 2)]).unwrap();
553        let m2 = SparseOps::from_triplets_u32(2, 2, vec![(0, 0, 3), (0, 1, 4)]).unwrap();
554        let m3 = SparseOps::from_triplets_u32(2, 2, vec![(1, 0, 5), (1, 1, 6)]).unwrap();
555        let m4 = SparseOps::from_triplets_u32(2, 2, vec![(0, 1, 7), (1, 0, 8)]).unwrap();
556
557        let result = SparseOps::parallel_sum_matrices(&[&m1, &m2, &m3, &m4]).unwrap();
558
559        assert_eq!(matrix_value(&result, 0, 0), 4);   // 1 + 3
560        assert_eq!(matrix_value(&result, 0, 1), 11);  // 4 + 7
561        assert_eq!(matrix_value(&result, 1, 0), 13);  // 5 + 8
562        assert_eq!(matrix_value(&result, 1, 1), 8);   // 2 + 6
563    }
564
565    #[test]
566    fn test_parallel_sum_eight_matrices() {
567        // Test with 8 matrices to verify the tree reduction works correctly
568        let matrices: Vec<CsrMatrix<u32>> = (0..8)
569            .map(|i| {
570                SparseOps::from_triplets_u32(2, 2, vec![
571                    (0, 0, i + 1),
572                    (1, 1, i + 1),
573                ]).unwrap()
574            })
575            .collect();
576
577        let matrix_refs: Vec<&CsrMatrix<u32>> = matrices.iter().collect();
578        let result = SparseOps::parallel_sum_matrices(&matrix_refs).unwrap();
579
580        // Sum should be 1+2+3+4+5+6+7+8 = 36
581        assert_eq!(matrix_value(&result, 0, 0), 36);
582        assert_eq!(matrix_value(&result, 1, 1), 36);
583        assert_eq!(matrix_value(&result, 0, 1), 0);
584        assert_eq!(matrix_value(&result, 1, 0), 0);
585    }
586
587    // ===== New comprehensive tests added before refactoring =====
588
589    #[test]
590    fn test_from_triplets_empty_input() {
591        let m = SparseOps::from_triplets_u32(5, 5, vec![]).unwrap();
592        assert_eq!(m.nrows(), 5);
593        assert_eq!(m.ncols(), 5);
594        assert_eq!(m.nnz(), 0);
595    }
596
597    #[test]
598    fn test_from_triplets_zero_dimensions() {
599        let m = SparseOps::from_triplets_u32(0, 0, vec![]).unwrap();
600        assert_eq!(m.nrows(), 0);
601        assert_eq!(m.ncols(), 0);
602    }
603
604    #[test]
605    fn test_from_triplets_out_of_bounds() {
606        let result = SparseOps::from_triplets_u32(2, 2, vec![(3, 0, 1)]);
607        assert!(result.is_err());
608    }
609
610    #[test]
611    fn test_from_triplets_duplicate_entries_summed() {
612        // COO format sums duplicate entries
613        let m = SparseOps::from_triplets_u32(2, 2, vec![(0, 0, 3), (0, 0, 7)]).unwrap();
614        assert_eq!(matrix_value(&m, 0, 0), 10);
615    }
616
617    #[test]
618    fn test_add_matrices_dimension_mismatch() {
619        let a = SparseOps::from_triplets_u32(2, 3, vec![(0, 0, 1)]).unwrap();
620        let b = SparseOps::from_triplets_u32(3, 2, vec![(0, 0, 1)]).unwrap();
621        assert!(SparseOps::add_matrices(&a, &b).is_err());
622    }
623
624    #[test]
625    fn test_add_matrices_one_empty() {
626        let a = SparseOps::from_triplets_u32(3, 3, vec![(0, 0, 5), (2, 2, 10)]).unwrap();
627        let b = CsrMatrix::<u32>::zeros(3, 3);
628        let result = SparseOps::add_matrices(&a, &b).unwrap();
629        assert_eq!(matrix_value(&result, 0, 0), 5);
630        assert_eq!(matrix_value(&result, 2, 2), 10);
631        assert_eq!(result.nnz(), 2);
632    }
633
634    #[test]
635    fn test_filter_columns_keeps_correct_subset() {
636        let m = SparseOps::from_triplets_u32(2, 4, vec![
637            (0, 0, 1), (0, 1, 2), (0, 2, 3), (0, 3, 4),
638            (1, 0, 5), (1, 1, 6), (1, 2, 7), (1, 3, 8),
639        ]).unwrap();
640
641        let filtered = SparseOps::filter_columns_u32(&m, &[1, 3]).unwrap();
642        assert_eq!(filtered.nrows(), 2);
643        assert_eq!(filtered.ncols(), 2);
644        assert_eq!(matrix_value(&filtered, 0, 0), 2); // old col 1 -> new col 0
645        assert_eq!(matrix_value(&filtered, 0, 1), 4); // old col 3 -> new col 1
646        assert_eq!(matrix_value(&filtered, 1, 0), 6);
647        assert_eq!(matrix_value(&filtered, 1, 1), 8);
648    }
649
650    #[test]
651    fn test_filter_columns_empty_keep() {
652        let m = SparseOps::from_triplets_u32(2, 3, vec![(0, 0, 1)]).unwrap();
653        let filtered = SparseOps::filter_columns_u32(&m, &[]).unwrap();
654        assert_eq!(filtered.ncols(), 0);
655        assert_eq!(filtered.nnz(), 0);
656    }
657
658    #[test]
659    fn test_filter_columns_preserves_sparsity() {
660        // A 100x100 matrix with only a few nonzeros
661        let m = SparseOps::from_triplets_u32(100, 100, vec![
662            (0, 10, 1), (50, 50, 2), (99, 99, 3),
663        ]).unwrap();
664        let keep: Vec<usize> = (0..50).collect();
665        let filtered = SparseOps::filter_columns_u32(&m, &keep).unwrap();
666        assert_eq!(filtered.ncols(), 50);
667        // Only (0, 10, 1) should survive (col 10 maps to new col 10)
668        assert_eq!(filtered.nnz(), 1);
669        assert_eq!(matrix_value(&filtered, 0, 10), 1);
670    }
671
672    #[test]
673    fn test_compute_row_sums_basic() {
674        let m = SparseOps::from_triplets_u32(3, 3, vec![
675            (0, 0, 1), (0, 1, 2), (0, 2, 3),
676            (1, 1, 10),
677            (2, 0, 5), (2, 2, 5),
678        ]).unwrap();
679        assert_eq!(SparseOps::compute_row_sums(&m), vec![6, 10, 10]);
680    }
681
682    #[test]
683    fn test_compute_row_sums_empty_matrix() {
684        let m = CsrMatrix::<u32>::zeros(3, 4);
685        assert_eq!(SparseOps::compute_row_sums(&m), vec![0, 0, 0]);
686    }
687
688    #[test]
689    fn test_compute_col_sums_basic() {
690        let m = SparseOps::from_triplets_u32(3, 3, vec![
691            (0, 0, 1), (1, 0, 2), (2, 0, 3),
692            (0, 2, 10), (1, 2, 20),
693        ]).unwrap();
694        assert_eq!(SparseOps::compute_col_sums(&m), vec![6, 0, 30]);
695    }
696
697    #[test]
698    fn test_compute_col_sums_empty() {
699        let m = CsrMatrix::<u32>::zeros(2, 5);
700        assert_eq!(SparseOps::compute_col_sums(&m), vec![0, 0, 0, 0, 0]);
701    }
702
703    #[test]
704    fn test_compute_masked_row_sums_basic() {
705        let m = SparseOps::from_triplets_u32(2, 4, vec![
706            (0, 0, 10), (0, 1, 20), (0, 2, 30), (0, 3, 40),
707            (1, 0, 1),  (1, 1, 2),  (1, 2, 3),  (1, 3, 4),
708        ]).unwrap();
709        let mask = vec![true, false, true, false];
710        let sums = SparseOps::compute_masked_row_sums(&m, &mask);
711        assert_eq!(sums, vec![40, 4]); // row0: 10+30, row1: 1+3
712    }
713
714    #[test]
715    fn test_compute_masked_row_sums_all_false() {
716        let m = SparseOps::from_triplets_u32(2, 3, vec![(0, 0, 99)]).unwrap();
717        let mask = vec![false, false, false];
718        assert_eq!(SparseOps::compute_masked_row_sums(&m, &mask), vec![0, 0]);
719    }
720
721    #[test]
722    fn test_compute_masked_row_sums_wrong_length() {
723        let m = SparseOps::from_triplets_u32(2, 3, vec![(0, 0, 1)]).unwrap();
724        let mask = vec![true, false]; // wrong length
725        // Should return zeros gracefully
726        assert_eq!(SparseOps::compute_masked_row_sums(&m, &mask), vec![0, 0]);
727    }
728
729    #[test]
730    fn test_element_wise_multiply_basic() {
731        let a = SparseOps::from_triplets_u32(2, 2, vec![
732            (0, 0, 10), (0, 1, 20), (1, 0, 30), (1, 1, 40),
733        ]).unwrap();
734        let b = SparseOps::from_triplets(2, 2, vec![
735            (0, 0, 1), (0, 1, 0), (1, 1, 1),
736        ]).unwrap();
737        let result = SparseOps::element_wise_multiply(&a, &b).unwrap();
738        assert_eq!(matrix_value(&result, 0, 0), 10);
739        assert_eq!(matrix_value(&result, 0, 1), 0); // b has 0 at (0,1)
740        assert_eq!(matrix_value(&result, 1, 0), 0); // b has no entry at (1,0)
741        assert_eq!(matrix_value(&result, 1, 1), 40);
742    }
743
744    #[test]
745    fn test_element_wise_multiply_dimension_mismatch() {
746        let a = SparseOps::from_triplets_u32(2, 3, vec![]).unwrap();
747        let b = SparseOps::from_triplets(3, 2, vec![]).unwrap();
748        assert!(SparseOps::element_wise_multiply(&a, &b).is_err());
749    }
750
751    #[test]
752    fn test_transpose_basic() {
753        let m = SparseOps::from_triplets_u32(2, 3, vec![
754            (0, 0, 1), (0, 2, 2), (1, 1, 3),
755        ]).unwrap();
756        let t = SparseOps::transpose_u32(&m);
757        assert_eq!(t.nrows(), 3);
758        assert_eq!(t.ncols(), 2);
759        assert_eq!(matrix_value(&t, 0, 0), 1);
760        assert_eq!(matrix_value(&t, 2, 0), 2);
761        assert_eq!(matrix_value(&t, 1, 1), 3);
762    }
763
764    #[test]
765    fn test_matrix_vector_multiply() {
766        let m = SparseOps::from_triplets_u32(2, 3, vec![
767            (0, 0, 1), (0, 1, 2), (0, 2, 3),
768            (1, 0, 4), (1, 1, 5), (1, 2, 6),
769        ]).unwrap();
770        let v = vec![1, 10, 100];
771        let result = SparseOps::matrix_vector_multiply(&m, &v).unwrap();
772        assert_eq!(result, vec![321, 654]);
773    }
774
775    #[test]
776    fn test_matrix_vector_multiply_dimension_mismatch() {
777        let m = SparseOps::from_triplets_u32(2, 3, vec![]).unwrap();
778        assert!(SparseOps::matrix_vector_multiply(&m, &[1, 2]).is_err());
779    }
780
781    #[test]
782    fn test_density_stats() {
783        let m = SparseOps::from_triplets_u32(10, 10, vec![
784            (0, 0, 1), (5, 5, 2), (9, 9, 3),
785        ]).unwrap();
786        let (density, nnz, total) = SparseOps::get_density_stats(&m);
787        assert_eq!(nnz, 3);
788        assert_eq!(total, 100);
789        assert!((density - 0.03).abs() < 1e-10);
790    }
791
792    #[test]
793    fn test_apply_threshold() {
794        let m = SparseOps::from_triplets_u32(3, 3, vec![
795            (0, 0, 1), (0, 1, 5), (1, 1, 10), (2, 2, 3),
796        ]).unwrap();
797        let filtered = m.apply_threshold(5);
798        assert_eq!(matrix_value(&filtered, 0, 0), 0); // 1 < 5
799        assert_eq!(matrix_value(&filtered, 0, 1), 5); // 5 >= 5
800        assert_eq!(matrix_value(&filtered, 1, 1), 10); // 10 >= 5
801        assert_eq!(matrix_value(&filtered, 2, 2), 0); // 3 < 5
802    }
803
804    // ===== End of new comprehensive tests =====
805
806    #[test]
807    fn test_parallel_sum_single_matrix() {
808        let m = SparseOps::from_triplets_u32(2, 2, vec![(0, 0, 42), (1, 1, 24)]).unwrap();
809        let result = SparseOps::parallel_sum_matrices(&[&m]).unwrap();
810
811        assert_eq!(matrix_value(&result, 0, 0), 42);
812        assert_eq!(matrix_value(&result, 1, 1), 24);
813    }
814
815    #[test]
816    fn test_parallel_sum_preserves_sparsity() {
817        // Create sparse matrices with only a few non-zero elements
818        let m1 = SparseOps::from_triplets_u32(100, 100, vec![
819            (0, 0, 1), (10, 10, 2), (50, 50, 3)
820        ]).unwrap();
821
822        let m2 = SparseOps::from_triplets_u32(100, 100, vec![
823            (0, 0, 10), (20, 20, 20), (50, 50, 30)
824        ]).unwrap();
825
826        let result = SparseOps::parallel_sum_matrices(&[&m1, &m2]).unwrap();
827
828        // Should maintain sparsity - only 5 non-zero elements
829        assert!(result.nnz() <= 6);
830        assert_eq!(matrix_value(&result, 0, 0), 11);
831        assert_eq!(matrix_value(&result, 10, 10), 2);
832        assert_eq!(matrix_value(&result, 20, 20), 20);
833        assert_eq!(matrix_value(&result, 50, 50), 33);
834    }
835
836    #[test]
837    fn test_parallel_sum_dimension_mismatch() {
838        let m1 = SparseOps::from_triplets_u32(2, 2, vec![(0, 0, 1)]).unwrap();
839        let m2 = SparseOps::from_triplets_u32(3, 3, vec![(0, 0, 1)]).unwrap();
840
841        let result = SparseOps::parallel_sum_matrices(&[&m1, &m2]);
842        assert!(result.is_err());
843    }
844
845    #[test]
846    fn test_parallel_sum_empty_list() {
847        let result = SparseOps::parallel_sum_matrices(&[]);
848        assert!(result.is_err());
849    }
850
851    #[test]
852    fn test_parallel_sum_large_scale() {
853        // Test with a more realistic scenario: 8 matrices of size 1000x1000
854        // with 1% density each
855        let n_matrices = 8;
856        let size = 1000;
857        let density = 0.01;
858        let n_nonzeros = (size as f64 * size as f64 * density) as usize;
859
860        let matrices: Vec<CsrMatrix<u32>> = (0..n_matrices)
861            .map(|matrix_idx| {
862                let triplets: Vec<(usize, usize, u32)> = (0..n_nonzeros)
863                    .map(|i| {
864                        let row = (i * 7 + matrix_idx * 13) % size;
865                        let col = (i * 11 + matrix_idx * 17) % size;
866                        (row, col, 1)
867                    })
868                    .collect();
869                SparseOps::from_triplets_u32(size, size, triplets).unwrap()
870            })
871            .collect();
872
873        let matrix_refs: Vec<&CsrMatrix<u32>> = matrices.iter().collect();
874        let result = SparseOps::parallel_sum_matrices(&matrix_refs).unwrap();
875
876        // Verify dimensions
877        assert_eq!(result.nrows(), size);
878        assert_eq!(result.ncols(), size);
879
880        // Verify sparsity is maintained (result should still be sparse)
881        let density_result = result.nnz() as f64 / (size * size) as f64;
882        assert!(density_result < 0.1, "Result should maintain sparsity");
883    }
884
885    // ===== Spill / Load tests =====
886
887    #[test]
888    fn test_spill_and_load_roundtrip() {
889        let m = SparseOps::from_triplets_u32(3, 4, vec![
890            (0, 0, 1), (0, 3, 42),
891            (1, 1, 100),
892            (2, 2, 7), (2, 3, 99),
893        ]).unwrap();
894
895        let dir = tempfile::tempdir().unwrap();
896        let path = dir.path().join("matrix.bin");
897
898        SparseOps::spill_to_file(&m, &path).unwrap();
899        let loaded = SparseOps::load_from_file(&path).unwrap();
900
901        assert_eq!(loaded.nrows(), 3);
902        assert_eq!(loaded.ncols(), 4);
903        assert_eq!(loaded.nnz(), 5);
904        assert_eq!(matrix_value(&loaded, 0, 0), 1);
905        assert_eq!(matrix_value(&loaded, 0, 3), 42);
906        assert_eq!(matrix_value(&loaded, 1, 1), 100);
907        assert_eq!(matrix_value(&loaded, 2, 2), 7);
908        assert_eq!(matrix_value(&loaded, 2, 3), 99);
909    }
910
911    #[test]
912    fn test_spill_and_load_empty_matrix() {
913        let m = CsrMatrix::<u32>::zeros(5, 10);
914        let dir = tempfile::tempdir().unwrap();
915        let path = dir.path().join("empty.bin");
916
917        SparseOps::spill_to_file(&m, &path).unwrap();
918        let loaded = SparseOps::load_from_file(&path).unwrap();
919
920        assert_eq!(loaded.nrows(), 5);
921        assert_eq!(loaded.ncols(), 10);
922        assert_eq!(loaded.nnz(), 0);
923    }
924
925    #[test]
926    fn test_spill_and_load_large_values() {
927        let m = SparseOps::from_triplets_u32(1, 2, vec![
928            (0, 0, u32::MAX), (0, 1, u32::MAX - 1),
929        ]).unwrap();
930        let dir = tempfile::tempdir().unwrap();
931        let path = dir.path().join("large.bin");
932
933        SparseOps::spill_to_file(&m, &path).unwrap();
934        let loaded = SparseOps::load_from_file(&path).unwrap();
935
936        assert_eq!(matrix_value(&loaded, 0, 0), u32::MAX);
937        assert_eq!(matrix_value(&loaded, 0, 1), u32::MAX - 1);
938    }
939
940    #[test]
941    fn test_estimate_csr_bytes_nonzero() {
942        let m = SparseOps::from_triplets_u32(10, 10, vec![
943            (0, 0, 1), (5, 5, 2), (9, 9, 3),
944        ]).unwrap();
945        let bytes = SparseOps::estimate_csr_bytes(&m);
946        // At least: 11 row_offsets * 8 + 3 col_indices * 8 + 3 values * 4 = 88 + 24 + 12 = 124
947        assert!(bytes >= 100, "Expected >= 100 bytes, got {}", bytes);
948    }
949}