redicat_lib/core/sparse.rs

//! High-performance sparse matrix utilities shared across REDICAT

use crate::core::error::{RedicatError, Result};
use itertools::Itertools;
use nalgebra_sparse::ops::serial::spadd_csr_prealloc;
use nalgebra_sparse::ops::Op;
use nalgebra_sparse::{CooMatrix, CsrMatrix};
use rayon::prelude::*;
use rustc_hash::FxHashMap;
use smallvec::SmallVec;

pub struct SparseOps;

impl SparseOps {
    /// Create a CSR matrix from COO-format triplets using nalgebra_sparse's native conversion.
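    ///
    /// A minimal usage sketch (the `redicat_lib::core::sparse` path is a
    /// hypothetical one, assumed from the crate layout):
    ///
    /// ```ignore
    /// use redicat_lib::core::sparse::SparseOps;
    ///
    /// // 2×3 matrix with two explicit entries.
    /// let m = SparseOps::from_triplets_u32(2, 3, vec![(0, 1, 5), (1, 2, 7)]).unwrap();
    /// assert_eq!(m.nnz(), 2);
    /// ```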
    pub fn from_triplets_u32(
        nrows: usize,
        ncols: usize,
        triplets: Vec<(usize, usize, u32)>,
    ) -> Result<CsrMatrix<u32>> {
        if nrows == 0 || ncols == 0 || triplets.is_empty() {
            return Ok(CsrMatrix::zeros(nrows, ncols));
        }

        // Validate indices up front for a clearer error than the COO builder's.
        for &(row, col, _) in &triplets {
            if row >= nrows || col >= ncols {
                return Err(RedicatError::InvalidInput(format!(
                    "Index ({}, {}) exceeds matrix dimensions ({}, {})",
                    row, col, nrows, ncols
                )));
            }
        }

        // Build in COO format first, then convert to CSR natively.
        let (row_indices, col_indices, values): (Vec<_>, Vec<_>, Vec<_>) =
            triplets.into_iter().multiunzip();

        let coo = CooMatrix::try_from_triplets(nrows, ncols, row_indices, col_indices, values)
            .map_err(|e| RedicatError::SparseMatrix(format!("COO creation failed: {:?}", e)))?;

        Ok(CsrMatrix::from(&coo))
    }

    /// Create a CSR matrix from triplets with u8 values.
    pub fn from_triplets(
        nrows: usize,
        ncols: usize,
        triplets: Vec<(usize, usize, u8)>,
    ) -> Result<CsrMatrix<u8>> {
        if nrows == 0 || ncols == 0 || triplets.is_empty() {
            return Ok(CsrMatrix::zeros(nrows, ncols));
        }

        let (row_indices, col_indices, values): (Vec<_>, Vec<_>, Vec<_>) =
            triplets.into_iter().multiunzip();

        // `try_from_triplets` itself rejects out-of-bounds indices here.
        let coo = CooMatrix::try_from_triplets(nrows, ncols, row_indices, col_indices, values)
            .map_err(|e| RedicatError::SparseMatrix(format!("COO creation failed: {:?}", e)))?;

        Ok(CsrMatrix::from(&coo))
    }

    /// Matrix addition using nalgebra_sparse's preallocated `spadd` operation.
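    ///
    /// A minimal sketch (hypothetical module path, assumed from the crate layout):
    ///
    /// ```ignore
    /// use redicat_lib::core::sparse::SparseOps;
    ///
    /// let a = SparseOps::from_triplets_u32(2, 2, vec![(0, 0, 1)]).unwrap();
    /// let b = SparseOps::from_triplets_u32(2, 2, vec![(0, 0, 2), (1, 1, 3)]).unwrap();
    /// let sum = SparseOps::add_matrices(&a, &b).unwrap();
    /// assert_eq!(sum.nnz(), 2); // (0,0) = 3 and (1,1) = 3
    /// ```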
    pub fn add_matrices(a: &CsrMatrix<u32>, b: &CsrMatrix<u32>) -> Result<CsrMatrix<u32>> {
        if a.nrows() != b.nrows() || a.ncols() != b.ncols() {
            return Err(RedicatError::DimensionMismatch {
                expected: format!("{}×{}", a.nrows(), a.ncols()),
                actual: format!("{}×{}", b.nrows(), b.ncols()),
            });
        }

        // Compute the union sparsity pattern of the two operands.
        let pattern = nalgebra_sparse::ops::serial::spadd_pattern(a.pattern(), b.pattern());

        // Pre-allocate the result matrix with the computed pattern.
        let nnz = pattern.nnz();
        let mut result = CsrMatrix::try_from_pattern_and_values(pattern, vec![0u32; nnz])
            .map_err(|e| {
                RedicatError::SparseMatrix(format!("Failed to create result matrix: {:?}", e))
            })?;

        // spadd_csr_prealloc(beta, C, alpha, Op<A>) computes C = beta*C + alpha*A,
        // so two calls accumulate A + B into the preallocated result.
        spadd_csr_prealloc(1u32, &mut result, 1u32, Op::NoOp(a))
            .map_err(|e| RedicatError::SparseMatrix(format!("Sparse addition failed: {:?}", e)))?;

        spadd_csr_prealloc(1u32, &mut result, 1u32, Op::NoOp(b))
            .map_err(|e| RedicatError::SparseMatrix(format!("Sparse addition failed: {:?}", e)))?;

        Ok(result)
    }

    /// Filter columns, keeping the columns listed in `keep_indices` (which
    /// should be unique); column `keep_indices[j]` becomes column `j`.
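    ///
    /// A minimal sketch (hypothetical module path, assumed from the crate layout):
    ///
    /// ```ignore
    /// use redicat_lib::core::sparse::SparseOps;
    ///
    /// let m = SparseOps::from_triplets_u32(1, 3, vec![(0, 0, 1), (0, 2, 9)]).unwrap();
    /// // Keep only original column 2; it becomes column 0 of the result.
    /// let f = SparseOps::filter_columns_u32(&m, &[2]).unwrap();
    /// assert_eq!((f.nrows(), f.ncols(), f.nnz()), (1, 1, 1));
    /// ```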
    pub fn filter_columns_u32(
        matrix: &CsrMatrix<u32>,
        keep_indices: &[usize],
    ) -> Result<CsrMatrix<u32>> {
        let nrows = matrix.nrows();
        let new_ncols = keep_indices.len();

        if new_ncols == 0 {
            return Ok(CsrMatrix::zeros(nrows, 0));
        }

        // Map old column index -> new column index (FxHashMap for fast integer keys).
        let col_map: FxHashMap<usize, usize> = keep_indices
            .iter()
            .enumerate()
            .map(|(new_idx, &old_idx)| (old_idx, new_idx))
            .collect();

        // Rebuild the CSR arrays row by row.
        let mut new_row_offsets = Vec::with_capacity(nrows + 1);
        let mut new_col_indices = Vec::new();
        let mut new_values = Vec::new();

        new_row_offsets.push(0);

        for row_idx in 0..nrows {
            let row = matrix.row(row_idx);

            // Collect surviving entries first: `keep_indices` need not be
            // increasing, but CSR requires sorted column indices per row.
            let mut entries: SmallVec<[(usize, u32); 32]> = SmallVec::new();
            for (&old_col, &val) in row.col_indices().iter().zip(row.values()) {
                if let Some(&new_col) = col_map.get(&old_col) {
                    entries.push((new_col, val));
                }
            }
            entries.sort_unstable_by_key(|&(col, _)| col);

            for (new_col, val) in entries {
                new_col_indices.push(new_col);
                new_values.push(val);
            }

            new_row_offsets.push(new_col_indices.len());
        }

        CsrMatrix::try_from_csr_data(
            nrows,
            new_ncols,
            new_row_offsets,
            new_col_indices,
            new_values,
        )
        .map_err(|e| {
            RedicatError::SparseMatrix(format!("Failed to create filtered matrix: {:?}", e))
        })
    }

    /// Parallel row sums with saturating accumulation, clamped to `u32::MAX`.
    pub fn compute_row_sums(matrix: &CsrMatrix<u32>) -> Vec<u32> {
        (0..matrix.nrows())
            .into_par_iter()
            .map(|row_idx| {
                let row = matrix.row(row_idx);
                row.values()
                    .iter()
                    .fold(0u64, |acc, &val| acc.saturating_add(val as u64))
                    .min(u32::MAX as u64) as u32
            })
            .collect()
    }

    /// Parallel column sums using a chunked reduction over the CSR rows.
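    ///
    /// A minimal sketch (hypothetical module path, assumed from the crate layout):
    ///
    /// ```ignore
    /// use redicat_lib::core::sparse::SparseOps;
    ///
    /// let m = SparseOps::from_triplets_u32(2, 2, vec![(0, 0, 1), (1, 0, 2)]).unwrap();
    /// assert_eq!(SparseOps::compute_col_sums(&m), vec![3, 0]);
    /// ```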
    pub fn compute_col_sums(matrix: &CsrMatrix<u32>) -> Vec<u32> {
        let ncols = matrix.ncols();

        // Process rows in chunks so each worker accumulates into a private buffer.
        let chunk_size = std::cmp::max(1, matrix.nrows() / rayon::current_num_threads());

        (0..matrix.nrows())
            .into_par_iter()
            .chunks(chunk_size)
            .map(|chunk| {
                let mut local_sums = vec![0u64; ncols];
                for row_idx in chunk {
                    let row = matrix.row(row_idx);
                    for (&col_idx, &val) in row.col_indices().iter().zip(row.values()) {
                        local_sums[col_idx] = local_sums[col_idx].saturating_add(val as u64);
                    }
                }
                local_sums
            })
            .reduce(
                || vec![0u64; ncols],
                |mut acc, local| {
                    for (i, val) in local.into_iter().enumerate() {
                        acc[i] = acc[i].saturating_add(val);
                    }
                    acc
                },
            )
            .into_iter()
            .map(|sum| (sum.min(u32::MAX as u64)) as u32)
            .collect()
    }

    /// Element-wise multiplication of `a` by the boolean mask `b`, using a
    /// two-pointer merge over sorted sparse rows.
    ///
    /// Wherever `b` has a non-zero entry, the corresponding value of `a` is
    /// kept unchanged; everything else is dropped. Because CSR guarantees
    /// column indices are sorted within each row, the merge intersects two
    /// rows in O(nnz_a + nnz_b), avoiding the hashing overhead and poor cache
    /// locality of a HashMap-based intersection.
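    ///
    /// A minimal sketch (hypothetical module path, assumed from the crate layout):
    ///
    /// ```ignore
    /// use redicat_lib::core::sparse::SparseOps;
    ///
    /// let a = SparseOps::from_triplets_u32(1, 2, vec![(0, 0, 4), (0, 1, 6)]).unwrap();
    /// let mask = SparseOps::from_triplets(1, 2, vec![(0, 1, 1u8)]).unwrap();
    /// let kept = SparseOps::element_wise_multiply(&a, &mask).unwrap();
    /// assert_eq!(kept.nnz(), 1); // only (0, 1, 6) survives the mask
    /// ```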
    pub fn element_wise_multiply(a: &CsrMatrix<u32>, b: &CsrMatrix<u8>) -> Result<CsrMatrix<u32>> {
        if a.nrows() != b.nrows() || a.ncols() != b.ncols() {
            return Err(RedicatError::DimensionMismatch {
                expected: format!("{}×{}", a.nrows(), a.ncols()),
                actual: format!("{}×{}", b.nrows(), b.ncols()),
            });
        }

        // Intersect each row pair in parallel with a two-pointer merge.
        // SmallVec avoids heap allocation for typical sparse rows (< 32 non-zeros).
        let triplets: Vec<(usize, usize, u32)> = (0..a.nrows())
            .into_par_iter()
            .flat_map(|row_idx| {
                let a_row = a.row(row_idx);
                let b_row = b.row(row_idx);

                let a_cols = a_row.col_indices();
                let a_vals = a_row.values();
                let b_cols = b_row.col_indices();
                let b_vals = b_row.values();

                let mut result: SmallVec<[(usize, usize, u32); 32]> = SmallVec::new();
                let mut a_idx = 0;
                let mut b_idx = 0;

                while a_idx < a_cols.len() && b_idx < b_cols.len() {
                    let a_col = a_cols[a_idx];
                    let b_col = b_cols[b_idx];

                    match a_col.cmp(&b_col) {
                        std::cmp::Ordering::Equal => {
                            // Intersection point: `b` acts as a mask, so any
                            // non-zero mask value keeps `a`'s entry unchanged.
                            if b_vals[b_idx] > 0 {
                                result.push((row_idx, a_col, a_vals[a_idx]));
                            }
                            a_idx += 1;
                            b_idx += 1;
                        }
                        std::cmp::Ordering::Less => {
                            // `a` has a column that `b` lacks; advance `a`.
                            a_idx += 1;
                        }
                        std::cmp::Ordering::Greater => {
                            // `b` has a column that `a` lacks; advance `b`.
                            b_idx += 1;
                        }
                    }
                }

                // Convert SmallVec to Vec for flat_map compatibility.
                result.into_vec()
            })
            .collect();

        Self::from_triplets_u32(a.nrows(), a.ncols(), triplets)
    }

    /// Transpose using nalgebra_sparse's native transpose.
    pub fn transpose_u32(matrix: &CsrMatrix<u32>) -> CsrMatrix<u32> {
        matrix.transpose()
    }

    /// Matrix-vector multiplication using a parallel row-wise kernel with
    /// saturating accumulation, clamped to `u32::MAX`.
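    ///
    /// A minimal sketch (hypothetical module path, assumed from the crate layout):
    ///
    /// ```ignore
    /// use redicat_lib::core::sparse::SparseOps;
    ///
    /// let m = SparseOps::from_triplets_u32(2, 2, vec![(0, 0, 2), (1, 1, 3)]).unwrap();
    /// let y = SparseOps::matrix_vector_multiply(&m, &[10, 100]).unwrap();
    /// assert_eq!(y, vec![20, 300]);
    /// ```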
    pub fn matrix_vector_multiply(matrix: &CsrMatrix<u32>, vector: &[u32]) -> Result<Vec<u32>> {
        if matrix.ncols() != vector.len() {
            return Err(RedicatError::DimensionMismatch {
                expected: format!("vector length = {}", matrix.ncols()),
                actual: format!("vector length = {}", vector.len()),
            });
        }

        // Accumulate each row's dot product in u64 to avoid intermediate overflow.
        let mut result = vec![0u64; matrix.nrows()];

        result
            .par_iter_mut()
            .enumerate()
            .for_each(|(row_idx, result_val)| {
                let row = matrix.row(row_idx);
                *result_val = row.col_indices().iter().zip(row.values()).fold(
                    0u64,
                    |acc, (&col_idx, &mat_val)| {
                        acc.saturating_add((mat_val as u64) * (vector[col_idx] as u64))
                    },
                );
            });

        Ok(result
            .into_iter()
            .map(|val| (val.min(u32::MAX as u64)) as u32)
            .collect())
    }

    /// Matrix density statistics, returned as `(density, nnz, total_elements)`.
    pub fn get_density_stats(matrix: &CsrMatrix<u32>) -> (f64, usize, usize) {
        let total_elements = matrix.nrows() * matrix.ncols();
        let nnz = matrix.nnz();
        let density = if total_elements > 0 {
            nnz as f64 / total_elements as f64
        } else {
            0.0
        };
        (density, nnz, total_elements)
    }
}

/// Extension trait for additional sparse matrix operations.
pub trait SparseMatrixExt<T> {
    fn apply_threshold(&self, threshold: T) -> CsrMatrix<T>
    where
        T: Copy + PartialOrd + Default + nalgebra::Scalar;
}

impl SparseMatrixExt<u32> for CsrMatrix<u32> {
    /// Keep only values `>= threshold`, dropping everything below it.
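    ///
    /// A minimal sketch (hypothetical module path, assumed from the crate layout):
    ///
    /// ```ignore
    /// use redicat_lib::core::sparse::{SparseMatrixExt, SparseOps};
    ///
    /// let m = SparseOps::from_triplets_u32(1, 3, vec![(0, 0, 1), (0, 1, 5)]).unwrap();
    /// assert_eq!(m.apply_threshold(5).nnz(), 1); // only (0, 1, 5) survives
    /// ```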
    fn apply_threshold(&self, threshold: u32) -> CsrMatrix<u32> {
        let triplets: Vec<(usize, usize, u32)> = self
            .triplet_iter()
            .filter_map(|(row, col, &val)| {
                if val >= threshold {
                    Some((row, col, val))
                } else {
                    None
                }
            })
            .collect();

        // Dimensions are preserved; fall back to an empty matrix on failure.
        SparseOps::from_triplets_u32(self.nrows(), self.ncols(), triplets)
            .unwrap_or_else(|_| CsrMatrix::zeros(self.nrows(), self.ncols()))
    }
}
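
// A small, self-contained test sketch exercising the paths above; the cases
// mirror the doc examples and assume only the items defined in this module.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn add_and_sum_round_trip() {
        let a = SparseOps::from_triplets_u32(2, 2, vec![(0, 0, 1), (1, 1, 2)]).unwrap();
        let b = SparseOps::from_triplets_u32(2, 2, vec![(0, 0, 3)]).unwrap();
        let sum = SparseOps::add_matrices(&a, &b).unwrap();
        assert_eq!(SparseOps::compute_row_sums(&sum), vec![4, 2]);
        assert_eq!(SparseOps::compute_col_sums(&sum), vec![4, 2]);
    }

    #[test]
    fn mask_multiply_keeps_only_masked_entries() {
        let a = SparseOps::from_triplets_u32(1, 2, vec![(0, 0, 4), (0, 1, 6)]).unwrap();
        let mask = SparseOps::from_triplets(1, 2, vec![(0, 1, 1u8)]).unwrap();
        let kept = SparseOps::element_wise_multiply(&a, &mask).unwrap();
        assert_eq!(kept.nnz(), 1);
    }

    #[test]
    fn threshold_drops_small_values() {
        let m = SparseOps::from_triplets_u32(1, 3, vec![(0, 0, 1), (0, 1, 5)]).unwrap();
        assert_eq!(m.apply_threshold(5).nnz(), 1);
    }
}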