use crate::core::error::{RedicatError, Result};
use itertools::Itertools;
use nalgebra_sparse::ops::serial::spadd_csr_prealloc;
use nalgebra_sparse::ops::Op;
use nalgebra_sparse::{CooMatrix, CsrMatrix};
use rayon::prelude::*;
use rustc_hash::FxHashMap;
use smallvec::SmallVec;
use std::io::{Read, Write};
use std::path::Path;
/// Namespace struct bundling CSR sparse-matrix helpers: construction from
/// triplets, addition, column filtering, row/column reductions, element-wise
/// masking, matrix-vector products, and binary spill/load to disk.
pub struct SparseOps;
impl SparseOps {
/// Builds a `CsrMatrix<u32>` from `(row, col, value)` triplets.
///
/// Degenerate shapes (`nrows == 0 || ncols == 0`) and an empty triplet list
/// both yield an all-zero matrix. Duplicate coordinates are summed by the
/// COO→CSR conversion.
///
/// # Errors
/// `InvalidInput` when a triplet lies outside `nrows × ncols`;
/// `SparseMatrix` if the COO construction itself fails.
pub fn from_triplets_u32(
    nrows: usize,
    ncols: usize,
    triplets: Vec<(usize, usize, u32)>,
) -> Result<CsrMatrix<u32>> {
    // Empty shape or no data: nothing to assemble.
    if nrows == 0 || ncols == 0 || triplets.is_empty() {
        return Ok(CsrMatrix::zeros(nrows, ncols));
    }
    // Validate coordinates up front so the caller gets a clear message
    // (reports the first offending triplet in input order).
    if let Some(&(r, c, _)) = triplets
        .iter()
        .find(|&&(r, c, _)| r >= nrows || c >= ncols)
    {
        return Err(RedicatError::InvalidInput(format!(
            "Index ({}, {}) exceeds matrix dimensions ({}, {})",
            r, c, nrows, ncols
        )));
    }
    let (rows, cols, vals): (Vec<_>, Vec<_>, Vec<_>) = triplets.into_iter().multiunzip();
    CooMatrix::try_from_triplets(nrows, ncols, rows, cols, vals)
        .map(|coo| CsrMatrix::from(&coo))
        .map_err(|e| RedicatError::SparseMatrix(format!("COO creation failed: {:?}", e)))
}
/// Builds a `CsrMatrix<u8>` from `(row, col, value)` triplets.
///
/// Mirrors [`SparseOps::from_triplets_u32`]: degenerate shapes or an empty
/// triplet list produce an all-zero matrix, and duplicate coordinates are
/// summed by the COO→CSR conversion.
///
/// # Errors
/// `InvalidInput` when a triplet lies outside `nrows × ncols`;
/// `SparseMatrix` if the COO construction itself fails.
pub fn from_triplets(
    nrows: usize,
    ncols: usize,
    triplets: Vec<(usize, usize, u8)>,
) -> Result<CsrMatrix<u8>> {
    if nrows == 0 || ncols == 0 || triplets.is_empty() {
        return Ok(CsrMatrix::zeros(nrows, ncols));
    }
    // Consistency fix: validate indices up front exactly like
    // `from_triplets_u32`, so out-of-range triplets yield the same clear
    // `InvalidInput` error instead of an opaque COO failure.
    for &(row, col, _) in &triplets {
        if row >= nrows || col >= ncols {
            return Err(RedicatError::InvalidInput(format!(
                "Index ({}, {}) exceeds matrix dimensions ({}, {})",
                row, col, nrows, ncols
            )));
        }
    }
    let (row_indices, col_indices, values): (Vec<_>, Vec<_>, Vec<_>) =
        triplets.into_iter().multiunzip();
    let coo = CooMatrix::try_from_triplets(nrows, ncols, row_indices, col_indices, values)
        .map_err(|e| RedicatError::SparseMatrix(format!("COO creation failed: {:?}", e)))?;
    Ok(CsrMatrix::from(&coo))
}
/// Element-wise sum `a + b` of two same-shaped `u32` CSR matrices.
///
/// # Errors
/// `DimensionMismatch` when the shapes differ; `SparseMatrix` if building the
/// result or performing the sparse addition fails.
pub fn add_matrices(a: &CsrMatrix<u32>, b: &CsrMatrix<u32>) -> Result<CsrMatrix<u32>> {
    if a.nrows() != b.nrows() || a.ncols() != b.ncols() {
        return Err(RedicatError::DimensionMismatch {
            expected: format!("{}×{}", a.nrows(), a.ncols()),
            actual: format!("{}×{}", b.nrows(), b.ncols()),
        });
    }
    // Union sparsity pattern of both operands. Read nnz *before* moving the
    // pattern into the result, avoiding the previous needless clone.
    let pattern = nalgebra_sparse::ops::serial::spadd_pattern(a.pattern(), b.pattern());
    let nnz = pattern.nnz();
    let mut result = CsrMatrix::try_from_pattern_and_values(pattern, vec![0u32; nnz])
        .map_err(|e| {
            RedicatError::SparseMatrix(format!("Failed to create result matrix: {:?}", e))
        })?;
    // result = 1*result + 1*a, then the same with b.
    spadd_csr_prealloc(1u32, &mut result, 1u32, Op::NoOp(a))
        .map_err(|e| RedicatError::SparseMatrix(format!("Sparse addition failed: {:?}", e)))?;
    spadd_csr_prealloc(1u32, &mut result, 1u32, Op::NoOp(b))
        .map_err(|e| RedicatError::SparseMatrix(format!("Sparse addition failed: {:?}", e)))?;
    Ok(result)
}
/// Sums an arbitrary number of same-shaped `u32` CSR matrices.
///
/// NOTE(review): despite the name, the accumulation itself is sequential
/// (`spadd_csr_prealloc` per matrix) — confirm whether parallel reduction
/// was intended by the original author.
///
/// # Errors
/// `EmptyData` for an empty slice; `DimensionMismatch` if any matrix differs
/// in shape from the first; `SparseMatrix` on sparse-addition failure.
pub fn parallel_sum_matrices(matrices: &[&CsrMatrix<u32>]) -> Result<CsrMatrix<u32>> {
    let first = match matrices.first() {
        None => return Err(RedicatError::EmptyData("No matrices to sum".to_string())),
        Some(m) => m,
    };
    if matrices.len() == 1 {
        return Ok((*first).clone());
    }
    let (nrows, ncols) = (first.nrows(), first.ncols());
    // Reject the first shape mismatch, if any.
    if let Some(bad) = matrices[1..]
        .iter()
        .find(|m| m.nrows() != nrows || m.ncols() != ncols)
    {
        return Err(RedicatError::DimensionMismatch {
            expected: format!("{}×{}", nrows, ncols),
            actual: format!("{}×{}", bad.nrows(), bad.ncols()),
        });
    }
    // Fold all sparsity patterns into their union.
    let union_pattern = matrices[1..].iter().fold(first.pattern().clone(), |acc, m| {
        nalgebra_sparse::ops::serial::spadd_pattern(&acc, m.pattern())
    });
    let nnz = union_pattern.nnz();
    let mut total = CsrMatrix::try_from_pattern_and_values(union_pattern, vec![0u32; nnz])
        .map_err(|e| {
            RedicatError::SparseMatrix(format!("Failed to create result matrix: {:?}", e))
        })?;
    // Accumulate every operand into the preallocated result.
    for m in matrices {
        spadd_csr_prealloc(1u32, &mut total, 1u32, Op::NoOp(*m)).map_err(|e| {
            RedicatError::SparseMatrix(format!("Sparse addition failed: {:?}", e))
        })?;
    }
    Ok(total)
}
/// Projects `matrix` onto the columns listed in `keep_indices`.
///
/// Column `keep_indices[j]` of the input becomes column `j` of the output,
/// so the slice also defines the new column order. Columns not listed are
/// dropped; listed indices outside the matrix simply match nothing.
///
/// # Errors
/// `SparseMatrix` if the filtered CSR data is rejected by `nalgebra_sparse`.
pub fn filter_columns_u32(
    matrix: &CsrMatrix<u32>,
    keep_indices: &[usize],
) -> Result<CsrMatrix<u32>> {
    let nrows = matrix.nrows();
    let new_ncols = keep_indices.len();
    if new_ncols == 0 {
        return Ok(CsrMatrix::zeros(nrows, 0));
    }
    // old column index -> new column index
    let col_map: FxHashMap<usize, usize> = keep_indices
        .iter()
        .enumerate()
        .map(|(new_idx, &old_idx)| (old_idx, new_idx))
        .collect();
    let mut new_row_offsets = Vec::with_capacity(nrows + 1);
    let mut new_col_indices = Vec::new();
    let mut new_values = Vec::new();
    new_row_offsets.push(0);
    // Bug fix: remapped entries must be re-sorted by their NEW column index.
    // Previously, a `keep_indices` that was not in ascending order produced
    // unsorted column indices within a row, which `try_from_csr_data`
    // rejects. A reusable per-row scratch buffer handles the reordering.
    let mut row_entries: Vec<(usize, u32)> = Vec::new();
    for row_idx in 0..nrows {
        let row = matrix.row(row_idx);
        row_entries.clear();
        for (&old_col, &val) in row.col_indices().iter().zip(row.values()) {
            if let Some(&new_col) = col_map.get(&old_col) {
                row_entries.push((new_col, val));
            }
        }
        // New column indices are unique per row (col_map values are unique),
        // so an unstable sort is safe.
        row_entries.sort_unstable_by_key(|&(col, _)| col);
        for &(col, val) in &row_entries {
            new_col_indices.push(col);
            new_values.push(val);
        }
        new_row_offsets.push(new_col_indices.len());
    }
    CsrMatrix::try_from_csr_data(
        nrows,
        new_ncols,
        new_row_offsets,
        new_col_indices,
        new_values,
    )
    .map_err(|e| {
        RedicatError::SparseMatrix(format!("Failed to create filtered matrix: {:?}", e))
    })
}
/// Per-row totals, computed in parallel over rows.
///
/// Each row is accumulated in `u64` with saturating adds, then clamped to
/// `u32::MAX` on output.
pub fn compute_row_sums(matrix: &CsrMatrix<u32>) -> Vec<u32> {
    (0..matrix.nrows())
        .into_par_iter()
        .map(|r| {
            let mut total: u64 = 0;
            for &v in matrix.row(r).values() {
                total = total.saturating_add(v as u64);
            }
            total.min(u32::MAX as u64) as u32
        })
        .collect()
}
/// Per-row totals restricted to columns where `mask[col]` is `true`.
///
/// Accumulates in `u64` with saturating adds, clamped to `u32::MAX`.
/// NOTE(review): a mask whose length differs from `ncols` silently yields all
/// zeros instead of an error — confirm callers depend on this leniency.
pub fn compute_masked_row_sums(matrix: &CsrMatrix<u32>, mask: &[bool]) -> Vec<u32> {
    if matrix.ncols() != mask.len() {
        return vec![0; matrix.nrows()];
    }
    (0..matrix.nrows())
        .into_par_iter()
        .map(|r| {
            let row = matrix.row(r);
            let mut total: u64 = 0;
            for (&col, &v) in row.col_indices().iter().zip(row.values()) {
                if mask[col] {
                    total = total.saturating_add(v as u64);
                }
            }
            total.min(u32::MAX as u64) as u32
        })
        .collect()
}
/// Per-column totals.
///
/// Rows are split into chunks across the rayon pool; each chunk accumulates
/// into a private `u64` vector, the partials are merged with a parallel
/// reduce, and the final totals are clamped to `u32::MAX`.
pub fn compute_col_sums(matrix: &CsrMatrix<u32>) -> Vec<u32> {
    let ncols = matrix.ncols();
    // Aim for roughly one chunk per worker thread (at least one row each).
    let rows_per_chunk = std::cmp::max(1, matrix.nrows() / rayon::current_num_threads());
    let totals = (0..matrix.nrows())
        .into_par_iter()
        .chunks(rows_per_chunk)
        .map(|rows| {
            let mut partial = vec![0u64; ncols];
            for r in rows {
                let row = matrix.row(r);
                for (&c, &v) in row.col_indices().iter().zip(row.values()) {
                    partial[c] = partial[c].saturating_add(v as u64);
                }
            }
            partial
        })
        .reduce(
            || vec![0u64; ncols],
            |mut merged, partial| {
                for (slot, v) in merged.iter_mut().zip(partial) {
                    *slot = slot.saturating_add(v);
                }
                merged
            },
        );
    totals
        .into_iter()
        .map(|t| t.min(u32::MAX as u64) as u32)
        .collect()
}
/// Masked element-wise product: an entry of `a` survives unchanged exactly
/// where `b` stores a value > 0 at the same coordinate; everything else
/// becomes an implicit zero.
///
/// # Errors
/// `DimensionMismatch` when the shapes differ; propagates any error from
/// rebuilding the result via `from_triplets_u32`.
pub fn element_wise_multiply(a: &CsrMatrix<u32>, b: &CsrMatrix<u8>) -> Result<CsrMatrix<u32>> {
    if a.nrows() != b.nrows() || a.ncols() != b.ncols() {
        return Err(RedicatError::DimensionMismatch {
            expected: format!("{}×{}", a.nrows(), a.ncols()),
            actual: format!("{}×{}", b.nrows(), b.ncols()),
        });
    }
    // Rows are processed in parallel; within a row, a two-pointer merge over
    // the CSR column indices (sorted by construction) finds coordinates
    // present in both matrices.
    let triplets: Vec<(usize, usize, u32)> = (0..a.nrows())
        .into_par_iter()
        .flat_map(|row_idx| {
            let a_row = a.row(row_idx);
            let b_row = b.row(row_idx);
            let a_cols = a_row.col_indices();
            let a_vals = a_row.values();
            let b_cols = b_row.col_indices();
            let b_vals = b_row.values();
            // SmallVec keeps short rows (<= 32 hits) off the heap.
            let mut result: SmallVec<[(usize, usize, u32); 32]> = SmallVec::new();
            let mut a_idx = 0;
            let mut b_idx = 0;
            while a_idx < a_cols.len() && b_idx < b_cols.len() {
                let a_col = a_cols[a_idx];
                let b_col = b_cols[b_idx];
                match a_col.cmp(&b_col) {
                    std::cmp::Ordering::Equal => {
                        // Explicitly stored zeros in `b` do NOT keep the entry.
                        if b_vals[b_idx] > 0 {
                            result.push((row_idx, a_col, a_vals[a_idx]));
                        }
                        a_idx += 1;
                        b_idx += 1;
                    }
                    std::cmp::Ordering::Less => {
                        // `a` has an entry where `b` has none: drop it.
                        a_idx += 1;
                    }
                    std::cmp::Ordering::Greater => {
                        // `b` has an entry where `a` has none: skip it.
                        b_idx += 1;
                    }
                }
            }
            result.into_vec()
        })
        .collect();
    Self::from_triplets_u32(a.nrows(), a.ncols(), triplets)
}
/// Transpose wrapper; delegates directly to `nalgebra_sparse`'s CSR
/// transpose (kept for a uniform `SparseOps` API surface).
pub fn transpose_u32(matrix: &CsrMatrix<u32>) -> CsrMatrix<u32> {
    matrix.transpose()
}
pub fn matrix_vector_multiply(matrix: &CsrMatrix<u32>, vector: &[u32]) -> Result<Vec<u32>> {
if matrix.ncols() != vector.len() {
return Err(RedicatError::DimensionMismatch {
expected: format!("vector length = {}", matrix.ncols()),
actual: format!("vector length = {}", vector.len()),
});
}
let mut result = vec![0u64; matrix.nrows()];
result
.par_iter_mut()
.enumerate()
.for_each(|(row_idx, result_val)| {
let row = matrix.row(row_idx);
*result_val = row.col_indices().iter().zip(row.values()).fold(
0u64,
|acc, (&col_idx, &mat_val)| {
acc.saturating_add((mat_val as u64) * (vector[col_idx] as u64))
},
);
});
Ok(result
.into_iter()
.map(|val| (val.min(u32::MAX as u64)) as u32)
.collect())
}
/// Returns `(density, nnz, nrows * ncols)` for `matrix`.
///
/// Density is the fraction of stored entries; a matrix with zero total
/// elements reports a density of `0.0`.
pub fn get_density_stats(matrix: &CsrMatrix<u32>) -> (f64, usize, usize) {
    let total = matrix.nrows() * matrix.ncols();
    let nnz = matrix.nnz();
    let density = match total {
        0 => 0.0,
        n => nnz as f64 / n as f64,
    };
    (density, nnz, total)
}
/// Serializes a matrix to `path` as little-endian binary: a header of
/// `nrows`, `ncols`, `nnz` (u64 each), then the row offsets (u64 each),
/// the column indices (u64 each), and finally the values (u32 each).
///
/// Fixes two issues in the previous implementation: values were dumped via
/// an `unsafe` native-endian byte transmute, so `load_from_file` (which
/// reads little-endian) could not round-trip on big-endian targets; and the
/// many small `write_all` calls hit the file unbuffered. Values are now
/// written explicitly little-endian through a `BufWriter`, with no `unsafe`.
///
/// # Errors
/// `Io` on any file-system failure.
pub fn spill_to_file(matrix: &CsrMatrix<u32>, path: &Path) -> Result<()> {
    let file = std::fs::File::create(path).map_err(RedicatError::Io)?;
    let mut writer = std::io::BufWriter::new(file);
    let (row_offsets, col_indices, values) = matrix.csr_data();
    let nrows = matrix.nrows() as u64;
    let ncols = matrix.ncols() as u64;
    let nnz = matrix.nnz() as u64;
    writer.write_all(&nrows.to_le_bytes()).map_err(RedicatError::Io)?;
    writer.write_all(&ncols.to_le_bytes()).map_err(RedicatError::Io)?;
    writer.write_all(&nnz.to_le_bytes()).map_err(RedicatError::Io)?;
    for &offset in row_offsets {
        writer.write_all(&(offset as u64).to_le_bytes()).map_err(RedicatError::Io)?;
    }
    for &col in col_indices {
        writer.write_all(&(col as u64).to_le_bytes()).map_err(RedicatError::Io)?;
    }
    for &val in values {
        writer.write_all(&val.to_le_bytes()).map_err(RedicatError::Io)?;
    }
    // Flush explicitly: Drop would flush but silently swallow errors.
    writer.flush().map_err(RedicatError::Io)?;
    Ok(())
}
/// Deserializes a matrix previously written by [`SparseOps::spill_to_file`]
/// (little-endian header, offsets, indices, then u32 values).
///
/// Improvement: the file is now wrapped in a `BufReader`, so the long run of
/// 8-byte field reads no longer costs one syscall each.
///
/// # Errors
/// `Io` on read failure (including truncated files); `SparseMatrix` if the
/// decoded CSR data is rejected by `nalgebra_sparse`.
pub fn load_from_file(path: &Path) -> Result<CsrMatrix<u32>> {
    // Inner helper: read one little-endian u64 from any reader.
    fn read_u64<R: Read>(r: &mut R) -> Result<u64> {
        let mut b = [0u8; 8];
        r.read_exact(&mut b).map_err(RedicatError::Io)?;
        Ok(u64::from_le_bytes(b))
    }
    let file = std::fs::File::open(path).map_err(RedicatError::Io)?;
    let mut reader = std::io::BufReader::new(file);
    let nrows = read_u64(&mut reader)? as usize;
    let ncols = read_u64(&mut reader)? as usize;
    let nnz = read_u64(&mut reader)? as usize;
    let mut row_offsets = Vec::with_capacity(nrows + 1);
    for _ in 0..=nrows {
        row_offsets.push(read_u64(&mut reader)? as usize);
    }
    let mut col_indices = Vec::with_capacity(nnz);
    for _ in 0..nnz {
        col_indices.push(read_u64(&mut reader)? as usize);
    }
    // Values are stored as consecutive little-endian u32s.
    let mut value_bytes = vec![0u8; nnz * std::mem::size_of::<u32>()];
    reader.read_exact(&mut value_bytes).map_err(RedicatError::Io)?;
    let values: Vec<u32> = value_bytes
        .chunks_exact(4)
        .map(|chunk| u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
        .collect();
    CsrMatrix::try_from_csr_data(nrows, ncols, row_offsets, col_indices, values).map_err(
        |e| RedicatError::SparseMatrix(format!("Failed to load spilled matrix: {:?}", e)),
    )
}
/// Approximate in-memory footprint of the three CSR buffers (row offsets,
/// column indices, values); struct overhead is not counted.
pub fn estimate_csr_bytes(matrix: &CsrMatrix<u32>) -> usize {
    let (offsets, indices, values) = matrix.csr_data();
    let index_size = std::mem::size_of::<usize>();
    let value_size = std::mem::size_of::<u32>();
    offsets.len() * index_size + indices.len() * index_size + values.len() * value_size
}
}
/// Extension trait adding value-threshold filtering to sparse matrices.
pub trait SparseMatrixExt<T> {
    /// Returns a copy keeping only stored entries with value >= `threshold`.
    fn apply_threshold(&self, threshold: T) -> CsrMatrix<T>
    where
        T: Copy + PartialOrd + Default + nalgebra::Scalar;
}
impl SparseMatrixExt<u32> for CsrMatrix<u32> {
    /// Keeps only stored entries whose value is at least `threshold`;
    /// everything below it becomes an implicit zero in the returned matrix.
    fn apply_threshold(&self, threshold: u32) -> CsrMatrix<u32> {
        let survivors: Vec<(usize, usize, u32)> = self
            .triplet_iter()
            .filter(|&(_, _, &val)| val >= threshold)
            .map(|(row, col, &val)| (row, col, val))
            .collect();
        // Rebuild failure cannot normally happen (coordinates come from a
        // valid matrix); fall back to an empty matrix of the same shape.
        SparseOps::from_triplets_u32(self.nrows(), self.ncols(), survivors)
            .unwrap_or_else(|_| CsrMatrix::zeros(self.nrows(), self.ncols()))
    }
}
#[cfg(test)]
mod tests {
use super::*;
// Reads the stored value at (row, col), returning 0 for implicit zeros.
fn matrix_value(matrix: &CsrMatrix<u32>, row: usize, col: usize) -> u32 {
let row_view = matrix.row(row);
row_view
.col_indices()
.iter()
.zip(row_view.values())
.find_map(|(&col_idx, &value)| (col_idx == col).then_some(value))
.unwrap_or(0)
}
// --- parallel_sum_matrices -------------------------------------------
#[test]
fn test_parallel_sum_two_matrices() {
let m1 = SparseOps::from_triplets_u32(3, 3, vec![
(0, 0, 1), (0, 1, 2),
(1, 1, 3), (1, 2, 4),
(2, 0, 5), (2, 2, 6),
]).unwrap();
let m2 = SparseOps::from_triplets_u32(3, 3, vec![
(0, 0, 10), (0, 2, 20),
(1, 1, 30),
(2, 1, 40), (2, 2, 50),
]).unwrap();
let result = SparseOps::parallel_sum_matrices(&[&m1, &m2]).unwrap();
assert_eq!(matrix_value(&result, 0, 0), 11); assert_eq!(matrix_value(&result, 0, 1), 2); assert_eq!(matrix_value(&result, 0, 2), 20); assert_eq!(matrix_value(&result, 1, 1), 33); assert_eq!(matrix_value(&result, 1, 2), 4); assert_eq!(matrix_value(&result, 2, 0), 5); assert_eq!(matrix_value(&result, 2, 1), 40); assert_eq!(matrix_value(&result, 2, 2), 56); }
#[test]
fn test_parallel_sum_multiple_matrices() {
let m1 = SparseOps::from_triplets_u32(2, 2, vec![(0, 0, 1), (1, 1, 2)]).unwrap();
let m2 = SparseOps::from_triplets_u32(2, 2, vec![(0, 0, 3), (0, 1, 4)]).unwrap();
let m3 = SparseOps::from_triplets_u32(2, 2, vec![(1, 0, 5), (1, 1, 6)]).unwrap();
let m4 = SparseOps::from_triplets_u32(2, 2, vec![(0, 1, 7), (1, 0, 8)]).unwrap();
let result = SparseOps::parallel_sum_matrices(&[&m1, &m2, &m3, &m4]).unwrap();
assert_eq!(matrix_value(&result, 0, 0), 4); assert_eq!(matrix_value(&result, 0, 1), 11); assert_eq!(matrix_value(&result, 1, 0), 13); assert_eq!(matrix_value(&result, 1, 1), 8); }
#[test]
fn test_parallel_sum_eight_matrices() {
// Diagonal values 1..=8 sum to 36 on each diagonal entry.
let matrices: Vec<CsrMatrix<u32>> = (0..8)
.map(|i| {
SparseOps::from_triplets_u32(2, 2, vec![
(0, 0, i + 1),
(1, 1, i + 1),
]).unwrap()
})
.collect();
let matrix_refs: Vec<&CsrMatrix<u32>> = matrices.iter().collect();
let result = SparseOps::parallel_sum_matrices(&matrix_refs).unwrap();
assert_eq!(matrix_value(&result, 0, 0), 36);
assert_eq!(matrix_value(&result, 1, 1), 36);
assert_eq!(matrix_value(&result, 0, 1), 0);
assert_eq!(matrix_value(&result, 1, 0), 0);
}
// --- triplet construction --------------------------------------------
#[test]
fn test_from_triplets_empty_input() {
let m = SparseOps::from_triplets_u32(5, 5, vec![]).unwrap();
assert_eq!(m.nrows(), 5);
assert_eq!(m.ncols(), 5);
assert_eq!(m.nnz(), 0);
}
#[test]
fn test_from_triplets_zero_dimensions() {
let m = SparseOps::from_triplets_u32(0, 0, vec![]).unwrap();
assert_eq!(m.nrows(), 0);
assert_eq!(m.ncols(), 0);
}
#[test]
fn test_from_triplets_out_of_bounds() {
let result = SparseOps::from_triplets_u32(2, 2, vec![(3, 0, 1)]);
assert!(result.is_err());
}
#[test]
fn test_from_triplets_duplicate_entries_summed() {
// Duplicate coordinates are summed by the COO -> CSR conversion.
let m = SparseOps::from_triplets_u32(2, 2, vec![(0, 0, 3), (0, 0, 7)]).unwrap();
assert_eq!(matrix_value(&m, 0, 0), 10);
}
// --- add_matrices -----------------------------------------------------
#[test]
fn test_add_matrices_dimension_mismatch() {
let a = SparseOps::from_triplets_u32(2, 3, vec![(0, 0, 1)]).unwrap();
let b = SparseOps::from_triplets_u32(3, 2, vec![(0, 0, 1)]).unwrap();
assert!(SparseOps::add_matrices(&a, &b).is_err());
}
#[test]
fn test_add_matrices_one_empty() {
let a = SparseOps::from_triplets_u32(3, 3, vec![(0, 0, 5), (2, 2, 10)]).unwrap();
let b = CsrMatrix::<u32>::zeros(3, 3);
let result = SparseOps::add_matrices(&a, &b).unwrap();
assert_eq!(matrix_value(&result, 0, 0), 5);
assert_eq!(matrix_value(&result, 2, 2), 10);
assert_eq!(result.nnz(), 2);
}
// --- filter_columns_u32 ----------------------------------------------
#[test]
fn test_filter_columns_keeps_correct_subset() {
let m = SparseOps::from_triplets_u32(2, 4, vec![
(0, 0, 1), (0, 1, 2), (0, 2, 3), (0, 3, 4),
(1, 0, 5), (1, 1, 6), (1, 2, 7), (1, 3, 8),
]).unwrap();
let filtered = SparseOps::filter_columns_u32(&m, &[1, 3]).unwrap();
assert_eq!(filtered.nrows(), 2);
assert_eq!(filtered.ncols(), 2);
assert_eq!(matrix_value(&filtered, 0, 0), 2); assert_eq!(matrix_value(&filtered, 0, 1), 4); assert_eq!(matrix_value(&filtered, 1, 0), 6);
assert_eq!(matrix_value(&filtered, 1, 1), 8);
}
#[test]
fn test_filter_columns_empty_keep() {
let m = SparseOps::from_triplets_u32(2, 3, vec![(0, 0, 1)]).unwrap();
let filtered = SparseOps::filter_columns_u32(&m, &[]).unwrap();
assert_eq!(filtered.ncols(), 0);
assert_eq!(filtered.nnz(), 0);
}
#[test]
fn test_filter_columns_preserves_sparsity() {
let m = SparseOps::from_triplets_u32(100, 100, vec![
(0, 10, 1), (50, 50, 2), (99, 99, 3),
]).unwrap();
let keep: Vec<usize> = (0..50).collect();
let filtered = SparseOps::filter_columns_u32(&m, &keep).unwrap();
assert_eq!(filtered.ncols(), 50);
assert_eq!(filtered.nnz(), 1);
assert_eq!(matrix_value(&filtered, 0, 10), 1);
}
// --- row / column sums -----------------------------------------------
#[test]
fn test_compute_row_sums_basic() {
let m = SparseOps::from_triplets_u32(3, 3, vec![
(0, 0, 1), (0, 1, 2), (0, 2, 3),
(1, 1, 10),
(2, 0, 5), (2, 2, 5),
]).unwrap();
assert_eq!(SparseOps::compute_row_sums(&m), vec![6, 10, 10]);
}
#[test]
fn test_compute_row_sums_empty_matrix() {
let m = CsrMatrix::<u32>::zeros(3, 4);
assert_eq!(SparseOps::compute_row_sums(&m), vec![0, 0, 0]);
}
#[test]
fn test_compute_col_sums_basic() {
let m = SparseOps::from_triplets_u32(3, 3, vec![
(0, 0, 1), (1, 0, 2), (2, 0, 3),
(0, 2, 10), (1, 2, 20),
]).unwrap();
assert_eq!(SparseOps::compute_col_sums(&m), vec![6, 0, 30]);
}
#[test]
fn test_compute_col_sums_empty() {
let m = CsrMatrix::<u32>::zeros(2, 5);
assert_eq!(SparseOps::compute_col_sums(&m), vec![0, 0, 0, 0, 0]);
}
#[test]
fn test_compute_masked_row_sums_basic() {
let m = SparseOps::from_triplets_u32(2, 4, vec![
(0, 0, 10), (0, 1, 20), (0, 2, 30), (0, 3, 40),
(1, 0, 1), (1, 1, 2), (1, 2, 3), (1, 3, 4),
]).unwrap();
let mask = vec![true, false, true, false];
let sums = SparseOps::compute_masked_row_sums(&m, &mask);
assert_eq!(sums, vec![40, 4]); }
#[test]
fn test_compute_masked_row_sums_all_false() {
let m = SparseOps::from_triplets_u32(2, 3, vec![(0, 0, 99)]).unwrap();
let mask = vec![false, false, false];
assert_eq!(SparseOps::compute_masked_row_sums(&m, &mask), vec![0, 0]);
}
#[test]
fn test_compute_masked_row_sums_wrong_length() {
// Mismatched mask length yields all zeros rather than an error.
let m = SparseOps::from_triplets_u32(2, 3, vec![(0, 0, 1)]).unwrap();
let mask = vec![true, false]; assert_eq!(SparseOps::compute_masked_row_sums(&m, &mask), vec![0, 0]);
}
// --- element-wise multiply / transpose / mat-vec ----------------------
#[test]
fn test_element_wise_multiply_basic() {
let a = SparseOps::from_triplets_u32(2, 2, vec![
(0, 0, 10), (0, 1, 20), (1, 0, 30), (1, 1, 40),
]).unwrap();
// Note: (0, 1, 0) is an explicitly stored zero in the mask.
let b = SparseOps::from_triplets(2, 2, vec![
(0, 0, 1), (0, 1, 0), (1, 1, 1),
]).unwrap();
let result = SparseOps::element_wise_multiply(&a, &b).unwrap();
assert_eq!(matrix_value(&result, 0, 0), 10);
assert_eq!(matrix_value(&result, 0, 1), 0); assert_eq!(matrix_value(&result, 1, 0), 0); assert_eq!(matrix_value(&result, 1, 1), 40);
}
#[test]
fn test_element_wise_multiply_dimension_mismatch() {
let a = SparseOps::from_triplets_u32(2, 3, vec![]).unwrap();
let b = SparseOps::from_triplets(3, 2, vec![]).unwrap();
assert!(SparseOps::element_wise_multiply(&a, &b).is_err());
}
#[test]
fn test_transpose_basic() {
let m = SparseOps::from_triplets_u32(2, 3, vec![
(0, 0, 1), (0, 2, 2), (1, 1, 3),
]).unwrap();
let t = SparseOps::transpose_u32(&m);
assert_eq!(t.nrows(), 3);
assert_eq!(t.ncols(), 2);
assert_eq!(matrix_value(&t, 0, 0), 1);
assert_eq!(matrix_value(&t, 2, 0), 2);
assert_eq!(matrix_value(&t, 1, 1), 3);
}
#[test]
fn test_matrix_vector_multiply() {
let m = SparseOps::from_triplets_u32(2, 3, vec![
(0, 0, 1), (0, 1, 2), (0, 2, 3),
(1, 0, 4), (1, 1, 5), (1, 2, 6),
]).unwrap();
let v = vec![1, 10, 100];
let result = SparseOps::matrix_vector_multiply(&m, &v).unwrap();
assert_eq!(result, vec![321, 654]);
}
#[test]
fn test_matrix_vector_multiply_dimension_mismatch() {
let m = SparseOps::from_triplets_u32(2, 3, vec![]).unwrap();
assert!(SparseOps::matrix_vector_multiply(&m, &[1, 2]).is_err());
}
// --- stats / threshold ------------------------------------------------
#[test]
fn test_density_stats() {
let m = SparseOps::from_triplets_u32(10, 10, vec![
(0, 0, 1), (5, 5, 2), (9, 9, 3),
]).unwrap();
let (density, nnz, total) = SparseOps::get_density_stats(&m);
assert_eq!(nnz, 3);
assert_eq!(total, 100);
assert!((density - 0.03).abs() < 1e-10);
}
#[test]
fn test_apply_threshold() {
let m = SparseOps::from_triplets_u32(3, 3, vec![
(0, 0, 1), (0, 1, 5), (1, 1, 10), (2, 2, 3),
]).unwrap();
let filtered = m.apply_threshold(5);
assert_eq!(matrix_value(&filtered, 0, 0), 0); assert_eq!(matrix_value(&filtered, 0, 1), 5); assert_eq!(matrix_value(&filtered, 1, 1), 10); assert_eq!(matrix_value(&filtered, 2, 2), 0); }
#[test]
fn test_parallel_sum_single_matrix() {
let m = SparseOps::from_triplets_u32(2, 2, vec![(0, 0, 42), (1, 1, 24)]).unwrap();
let result = SparseOps::parallel_sum_matrices(&[&m]).unwrap();
assert_eq!(matrix_value(&result, 0, 0), 42);
assert_eq!(matrix_value(&result, 1, 1), 24);
}
#[test]
fn test_parallel_sum_preserves_sparsity() {
let m1 = SparseOps::from_triplets_u32(100, 100, vec![
(0, 0, 1), (10, 10, 2), (50, 50, 3)
]).unwrap();
let m2 = SparseOps::from_triplets_u32(100, 100, vec![
(0, 0, 10), (20, 20, 20), (50, 50, 30)
]).unwrap();
let result = SparseOps::parallel_sum_matrices(&[&m1, &m2]).unwrap();
assert!(result.nnz() <= 6);
assert_eq!(matrix_value(&result, 0, 0), 11);
assert_eq!(matrix_value(&result, 10, 10), 2);
assert_eq!(matrix_value(&result, 20, 20), 20);
assert_eq!(matrix_value(&result, 50, 50), 33);
}
#[test]
fn test_parallel_sum_dimension_mismatch() {
let m1 = SparseOps::from_triplets_u32(2, 2, vec![(0, 0, 1)]).unwrap();
let m2 = SparseOps::from_triplets_u32(3, 3, vec![(0, 0, 1)]).unwrap();
let result = SparseOps::parallel_sum_matrices(&[&m1, &m2]);
assert!(result.is_err());
}
#[test]
fn test_parallel_sum_empty_list() {
let result = SparseOps::parallel_sum_matrices(&[]);
assert!(result.is_err());
}
#[test]
fn test_parallel_sum_large_scale() {
let n_matrices = 8;
let size = 1000;
let density = 0.01;
let n_nonzeros = (size as f64 * size as f64 * density) as usize;
let matrices: Vec<CsrMatrix<u32>> = (0..n_matrices)
.map(|matrix_idx| {
let triplets: Vec<(usize, usize, u32)> = (0..n_nonzeros)
.map(|i| {
let row = (i * 7 + matrix_idx * 13) % size;
let col = (i * 11 + matrix_idx * 17) % size;
(row, col, 1)
})
.collect();
SparseOps::from_triplets_u32(size, size, triplets).unwrap()
})
.collect();
let matrix_refs: Vec<&CsrMatrix<u32>> = matrices.iter().collect();
let result = SparseOps::parallel_sum_matrices(&matrix_refs).unwrap();
assert_eq!(result.nrows(), size);
assert_eq!(result.ncols(), size);
let density_result = result.nnz() as f64 / (size * size) as f64;
assert!(density_result < 0.1, "Result should maintain sparsity");
}
// --- spill / load round-trips -----------------------------------------
#[test]
fn test_spill_and_load_roundtrip() {
let m = SparseOps::from_triplets_u32(3, 4, vec![
(0, 0, 1), (0, 3, 42),
(1, 1, 100),
(2, 2, 7), (2, 3, 99),
]).unwrap();
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("matrix.bin");
SparseOps::spill_to_file(&m, &path).unwrap();
let loaded = SparseOps::load_from_file(&path).unwrap();
assert_eq!(loaded.nrows(), 3);
assert_eq!(loaded.ncols(), 4);
assert_eq!(loaded.nnz(), 5);
assert_eq!(matrix_value(&loaded, 0, 0), 1);
assert_eq!(matrix_value(&loaded, 0, 3), 42);
assert_eq!(matrix_value(&loaded, 1, 1), 100);
assert_eq!(matrix_value(&loaded, 2, 2), 7);
assert_eq!(matrix_value(&loaded, 2, 3), 99);
}
#[test]
fn test_spill_and_load_empty_matrix() {
let m = CsrMatrix::<u32>::zeros(5, 10);
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("empty.bin");
SparseOps::spill_to_file(&m, &path).unwrap();
let loaded = SparseOps::load_from_file(&path).unwrap();
assert_eq!(loaded.nrows(), 5);
assert_eq!(loaded.ncols(), 10);
assert_eq!(loaded.nnz(), 0);
}
#[test]
fn test_spill_and_load_large_values() {
let m = SparseOps::from_triplets_u32(1, 2, vec![
(0, 0, u32::MAX), (0, 1, u32::MAX - 1),
]).unwrap();
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("large.bin");
SparseOps::spill_to_file(&m, &path).unwrap();
let loaded = SparseOps::load_from_file(&path).unwrap();
assert_eq!(matrix_value(&loaded, 0, 0), u32::MAX);
assert_eq!(matrix_value(&loaded, 0, 1), u32::MAX - 1);
}
#[test]
fn test_estimate_csr_bytes_nonzero() {
let m = SparseOps::from_triplets_u32(10, 10, vec![
(0, 0, 1), (5, 5, 2), (9, 9, 3),
]).unwrap();
let bytes = SparseOps::estimate_csr_bytes(&m);
assert!(bytes >= 100, "Expected >= 100 bytes, got {}", bytes);
}
}