use std::collections::HashMap;
use std::iter::Sum;
use std::ops::AddAssign;
use super::{
BatchMatrixMean, BatchMatrixVariance, MatrixMinMax, MatrixNonZero, MatrixSum, MatrixVariance,
};
use crate::sparse::MatrixNTop;
use crate::utils::Normalize;
use crate::utils::{BatchIdentifier, Log1P};
use anyhow::{anyhow, Ok};
use nalgebra_sparse::CsrMatrix;
use num_traits::{Float, NumCast, PrimInt, Unsigned, Zero};
use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
use rayon::slice::ParallelSlice;
use single_utilities::traits::{FloatOpsTS, NumericOps};
use single_utilities::types::Direction;
/// Cutoff below which the methods in this file take their sequential code path instead of the
/// rayon one. Depending on the method it is compared against the number of stored entries,
/// the number of rows, or the number of selected mask entries.
const PARALLEL_THRESHOLD: usize = 200_000;
/// Block length used by the sequential accumulation loop in `sum_col`.
const CHUNK_SIZE: usize = 512;
impl<M: NumericOps> MatrixNonZero for CsrMatrix<M> {
/// Counts the stored entries of every column.
///
/// Returns one counter per column. An error is returned when a stored column index is not
/// smaller than the column count. Inputs with at least `PARALLEL_THRESHOLD` stored entries
/// are counted in parallel blocks whose partial histograms are merged afterwards.
fn nonzero_col<T>(&self) -> anyhow::Result<Vec<T>>
where
    T: PrimInt + Unsigned + Zero + AddAssign + Send + Sync,
{
    let width = self.ncols();
    let indices = self.col_indices();
    let stored = indices.len();
    if stored == 0 || width == 0 {
        return Ok(vec![T::zero(); width]);
    }
    // Validate once up front so the indexing below cannot go out of bounds.
    match indices.iter().max() {
        Some(&largest) if largest >= width => {
            return Err(anyhow::anyhow!(
                "Invalid column index {} exceeds matrix column count {}",
                largest,
                width
            ));
        }
        _ => {}
    }
    if stored >= PARALLEL_THRESHOLD {
        let merged = indices
            .par_chunks(8192)
            .map(|block| {
                let mut partial = vec![T::zero(); width];
                for &col in block {
                    partial[col] += T::one();
                }
                partial
            })
            .reduce(
                || vec![T::zero(); width],
                |mut total, partial| {
                    for (slot, extra) in total.iter_mut().zip(partial) {
                        *slot += extra;
                    }
                    total
                },
            );
        return Ok(merged);
    }
    let mut counts = vec![T::zero(); width];
    for &col in indices {
        counts[col] += T::one();
    }
    Ok(counts)
}
/// Counts the stored entries of every row as the difference of consecutive row offsets.
///
/// Fails when the offsets array does not have `nrows + 1` elements or when a count cannot
/// be represented in `T`.
fn nonzero_row<T>(&self) -> anyhow::Result<Vec<T>>
where
    T: PrimInt + Unsigned + Zero + AddAssign + Send + Sync,
{
    let offsets = self.row_offsets();
    let height = self.nrows();
    if height == 0 {
        return Ok(Vec::new());
    }
    let expected = height + 1;
    if offsets.len() != expected {
        return Err(anyhow::anyhow!(
            "Invalid row offsets: expected {} elements, got {}",
            expected,
            offsets.len()
        ));
    }
    // Shared by the sequential and the parallel path.
    let width_of = |row: usize| -> anyhow::Result<T> {
        let span = offsets[row + 1] - offsets[row];
        T::from(span)
            .ok_or_else(|| anyhow::anyhow!("Count {} exceeds target type capacity", span))
    };
    if height >= PARALLEL_THRESHOLD {
        (0..height).into_par_iter().map(width_of).collect()
    } else {
        (0..height).map(width_of).collect()
    }
}
/// Adds this matrix's per-column stored-entry counts onto `reference`.
///
/// Column indices that fall outside `reference` are skipped.
fn nonzero_col_chunk<T>(&self, reference: &mut [T]) -> anyhow::Result<()>
where
    T: PrimInt + Unsigned + Zero + AddAssign,
{
    for &col in self.col_indices() {
        if let Some(slot) = reference.get_mut(col) {
            *slot += T::one();
        }
    }
    Ok(())
}
/// Adds this matrix's per-row stored-entry counts onto `reference`.
///
/// Rows beyond the end of `reference` are not written, but their counts are still computed
/// and converted, so a conversion failure there is still reported.
fn nonzero_row_chunk<T>(&self, reference: &mut [T]) -> anyhow::Result<()>
where
    T: PrimInt + Unsigned + Zero + AddAssign,
{
    for (row, pair) in self.row_offsets().windows(2).enumerate() {
        let (lo, hi) = (pair[0], pair[1]);
        let span = hi
            .checked_sub(lo)
            .ok_or_else(|| anyhow!("Subtraction overflow"))?;
        let converted =
            T::from(span).ok_or_else(|| anyhow!("Failed to convert to target type"))?;
        if let Some(slot) = reference.get_mut(row) {
            *slot += converted;
        }
    }
    Ok(())
}
/// Counts stored entries per column, considering only rows whose `mask` entry is `true`.
///
/// Fails when `mask` has fewer entries than the matrix has rows.
fn nonzero_col_masked<T>(&self, mask: &[bool]) -> anyhow::Result<Vec<T>>
where
    T: PrimInt + Unsigned + Zero + AddAssign,
{
    let height = self.nrows();
    if mask.len() < height {
        return Err(anyhow::anyhow!(
            "Mask length ({}) is less than number of rows ({})",
            mask.len(),
            height
        ));
    }
    let offsets = self.row_offsets();
    let indices = self.col_indices();
    let mut counts = vec![T::zero(); self.ncols()];
    for row in (0..height).filter(|&r| mask[r]) {
        for &col in &indices[offsets[row]..offsets[row + 1]] {
            counts[col] += T::one();
        }
    }
    Ok(counts)
}
/// Counts stored entries per row, considering only columns whose `mask` entry is `true`.
///
/// Fails when `mask` has fewer entries than the matrix has columns. Matrices with at least
/// `PARALLEL_THRESHOLD` stored entries are processed row-parallel.
fn nonzero_row_masked<T>(&self, mask: &[bool]) -> anyhow::Result<Vec<T>>
where
    T: PrimInt + Unsigned + Zero + AddAssign + Send + Sync,
{
    let width = self.ncols();
    if mask.len() < width {
        return Err(anyhow::anyhow!(
            "Mask length ({}) is less than number of columns ({})",
            mask.len(),
            width
        ));
    }
    let height = self.nrows();
    let offsets = self.row_offsets();
    let indices = self.col_indices();
    if height == 0 {
        return Ok(Vec::new());
    }
    // Nothing selected: every row count is zero.
    if !mask.iter().any(|&keep| keep) {
        return Ok(vec![T::zero(); height]);
    }
    // Per-row kernel shared by both execution paths.
    let count_row = |row: usize| -> T {
        indices[offsets[row]..offsets[row + 1]]
            .iter()
            .filter(|&&col| mask[col])
            .fold(T::zero(), |acc, _| acc + T::one())
    };
    if self.nnz() < PARALLEL_THRESHOLD {
        Ok((0..height).map(count_row).collect())
    } else {
        Ok((0..height).into_par_iter().map(count_row).collect())
    }
}
}
impl<M: NumericOps + Send + Sync> MatrixSum for CsrMatrix<M> {
type Item = M;
/// Sums the stored values of every column, converted to `T`.
///
/// Panics if a stored value cannot be converted to `T`. Inputs with at least
/// `PARALLEL_THRESHOLD` stored entries are reduced in parallel blocks of 8192 entries.
fn sum_col<T>(&self) -> anyhow::Result<Vec<T>>
where
    T: Float + NumCast + AddAssign + std::iter::Sum + Send + Sync,
    Self::Item: NumCast,
{
    let width = self.ncols();
    let indices = self.col_indices();
    let data = self.values();
    let stored = data.len();
    if stored == 0 || width == 0 {
        return Ok(vec![T::zero(); width]);
    }
    if stored >= PARALLEL_THRESHOLD {
        let merged = (0..stored)
            .into_par_iter()
            .chunks(8192)
            .map(|block| {
                let mut partial = vec![T::zero(); width];
                for pos in block {
                    partial[indices[pos]] += T::from(data[pos]).unwrap();
                }
                partial
            })
            .reduce(
                || vec![T::zero(); width],
                |mut total, partial| {
                    for (slot, extra) in total.iter_mut().zip(partial) {
                        *slot += extra;
                    }
                    total
                },
            );
        return Ok(merged);
    }
    // Sequential path: walk the entries in blocks of CHUNK_SIZE, in storage order.
    let mut totals = vec![T::zero(); width];
    for (idx_block, val_block) in indices.chunks(CHUNK_SIZE).zip(data.chunks(CHUNK_SIZE)) {
        for (&col, &value) in idx_block.iter().zip(val_block) {
            totals[col] += T::from(value).unwrap();
        }
    }
    Ok(totals)
}
/// Sums the stored values of every row, converted to `T`.
///
/// Panics if a stored value cannot be converted to `T`. The sequential and parallel paths
/// use different accumulation strategies for long rows (plain running sum vs. sums of
/// groups of eight), selected by the total number of stored entries.
fn sum_row<T>(&self) -> anyhow::Result<Vec<T>>
where
    T: Float + NumCast + AddAssign + std::iter::Sum + Send + Sync,
    Self::Item: NumCast,
{
    let height = self.nrows();
    let data = self.values();
    let offsets = self.row_offsets();
    if height == 0 {
        return Ok(Vec::new());
    }
    if data.len() < PARALLEL_THRESHOLD {
        let totals: Vec<T> = (0..height)
            .map(|row| {
                let entries = &data[offsets[row]..offsets[row + 1]];
                if entries.len() < 16 {
                    entries.iter().map(|&v| T::from(v).unwrap()).sum::<T>()
                } else {
                    // Running left-to-right sum starting from zero.
                    entries
                        .iter()
                        .fold(T::zero(), |acc, &v| acc + T::from(v).unwrap())
                }
            })
            .collect();
        Ok(totals)
    } else {
        let totals: Vec<T> = (0..height)
            .into_par_iter()
            .map(|row| {
                let entries = &data[offsets[row]..offsets[row + 1]];
                if entries.len() < 32 {
                    entries.iter().map(|&v| T::from(v).unwrap()).sum::<T>()
                } else {
                    // Sum each group of eight first, then add the group sums in order.
                    entries.chunks(8).fold(T::zero(), |acc, group| {
                        let partial: T = group.iter().map(|&v| T::from(v).unwrap()).sum();
                        acc + partial
                    })
                }
            })
            .collect();
        Ok(totals)
    }
}
/// Adds this matrix's per-column value sums onto `reference`.
///
/// Entries whose column index falls outside `reference` are skipped. Panics if a visited
/// value cannot be converted to `T`.
fn sum_col_chunk<T>(&self, reference: &mut [T]) -> anyhow::Result<()>
where
    T: Float + NumCast + AddAssign + std::iter::Sum,
    Self::Item: NumCast,
{
    let pairs = self.col_indices().iter().zip(self.values());
    for (&col, &value) in pairs {
        if let Some(slot) = reference.get_mut(col) {
            *slot += T::from(value).unwrap();
        }
    }
    Ok(())
}
/// Writes the sum of each row's stored values into `reference[row]`.
///
/// Panics if `reference` has fewer elements than the matrix has rows, or if a value cannot
/// be converted to `T`.
///
/// NOTE(review): this overwrites the slot, whereas `sum_col_chunk` adds to it — confirm
/// that overwriting is the intended chunk semantics for rows.
fn sum_row_chunk<T>(&self, reference: &mut [T]) -> anyhow::Result<()>
where
    T: Float + NumCast + AddAssign + std::iter::Sum,
    Self::Item: NumCast,
{
    for (row, lane) in self.row_iter().enumerate() {
        let total: T = lane.values().iter().map(|&v| T::from(v).unwrap()).sum();
        reference[row] = total;
    }
    Ok(())
}
/// Sums stored values per column, considering only rows whose `mask` entry is `true`.
///
/// # Errors
/// Fails when `mask` has fewer entries than the matrix has rows.
///
/// # Panics
/// Panics if a stored value cannot be converted to `T`.
///
/// Mask entries beyond the last row are ignored. Previously a `true` entry past the last
/// row indexed the row offsets out of bounds.
fn sum_col_masked<T>(&self, mask: &[bool]) -> anyhow::Result<Vec<T>>
where
    T: Float + NumCast + AddAssign + Sum + Send + Sync,
{
    let n_rows = self.nrows();
    if mask.len() < n_rows {
        return Err(anyhow::anyhow!(
            "Mask length ({}) is less than number of rows ({})",
            mask.len(),
            n_rows
        ));
    }
    // Only the first `n_rows` entries correspond to rows of this matrix.
    let mask = &mask[..n_rows];
    let n_cols = self.ncols();
    let row_offsets = self.row_offsets();
    let col_indices = self.col_indices();
    let values = self.values();
    let masked_count = mask.iter().filter(|&&m| m).count();
    if masked_count == 0 {
        return Ok(vec![T::zero(); n_cols]);
    }
    if masked_count < PARALLEL_THRESHOLD {
        let mut result = vec![T::zero(); n_cols];
        for (row, &is_included) in mask.iter().enumerate() {
            if is_included {
                for idx in row_offsets[row]..row_offsets[row + 1] {
                    result[col_indices[idx]] += T::from(values[idx]).unwrap();
                }
            }
        }
        Ok(result)
    } else {
        // Each block of 256 rows builds a private accumulator; the accumulators are then
        // merged element-wise.
        let result = (0..n_rows)
            .into_par_iter()
            .chunks(256)
            .map(|row_chunk| {
                let mut local_sums = vec![T::zero(); n_cols];
                for row in row_chunk {
                    if mask[row] {
                        for idx in row_offsets[row]..row_offsets[row + 1] {
                            local_sums[col_indices[idx]] += T::from(values[idx]).unwrap();
                        }
                    }
                }
                local_sums
            })
            .reduce(
                || vec![T::zero(); n_cols],
                |mut acc, local| {
                    for (slot, val) in acc.iter_mut().zip(local) {
                        *slot += val;
                    }
                    acc
                },
            );
        Ok(result)
    }
}
/// Sums stored values per row, considering only columns whose `mask` entry is `true`.
///
/// Fails when `mask` has fewer entries than the matrix has columns. Panics if a selected
/// value cannot be converted to `T`.
fn sum_row_masked<T>(&self, mask: &[bool]) -> anyhow::Result<Vec<T>>
where
    T: Float + NumCast + AddAssign + Sum + Send + Sync,
{
    let width = self.ncols();
    if mask.len() < width {
        return Err(anyhow::anyhow!(
            "Mask length ({}) is less than number of columns ({})",
            mask.len(),
            width
        ));
    }
    let height = self.nrows();
    let offsets = self.row_offsets();
    let indices = self.col_indices();
    let data = self.values();
    if height == 0 {
        return Ok(Vec::new());
    }
    // Nothing selected: every row sum is zero.
    if !mask.iter().any(|&keep| keep) {
        return Ok(vec![T::zero(); height]);
    }
    // Per-row kernel shared by both execution paths.
    let row_total = |row: usize| -> T {
        (offsets[row]..offsets[row + 1])
            .filter(|&pos| mask[indices[pos]])
            .fold(T::zero(), |acc, pos| acc + T::from(data[pos]).unwrap())
    };
    if self.nnz() < PARALLEL_THRESHOLD {
        Ok((0..height).map(row_total).collect())
    } else {
        Ok((0..height).into_par_iter().map(row_total).collect())
    }
}
/// Sums the squares of the stored values of every column, converted to `T`.
///
/// Panics if a stored value cannot be converted to `T`. Inputs with at least
/// `PARALLEL_THRESHOLD` stored entries are reduced in parallel blocks of 8192 entries.
fn sum_col_squared<T>(&self) -> anyhow::Result<Vec<T>>
where
    T: Float + NumCast + AddAssign + Sum + Send + Sync,
{
    let width = self.ncols();
    let indices = self.col_indices();
    let data = self.values();
    let stored = data.len();
    if stored == 0 || width == 0 {
        return Ok(vec![T::zero(); width]);
    }
    if stored >= PARALLEL_THRESHOLD {
        let merged = (0..stored)
            .into_par_iter()
            .chunks(8192)
            .map(|block| {
                let mut partial = vec![T::zero(); width];
                for pos in block {
                    let x = T::from(data[pos]).unwrap();
                    partial[indices[pos]] += x * x;
                }
                partial
            })
            .reduce(
                || vec![T::zero(); width],
                |mut total, partial| {
                    for (slot, extra) in total.iter_mut().zip(partial) {
                        *slot += extra;
                    }
                    total
                },
            );
        return Ok(merged);
    }
    let mut totals = vec![T::zero(); width];
    for pos in 0..stored {
        let x = T::from(data[pos]).unwrap();
        totals[indices[pos]] += x * x;
    }
    Ok(totals)
}
/// Sums the squares of the stored values of every row, converted to `T`.
///
/// Returns one value per row. The accumulator was previously sized by the column count
/// while being indexed by row, which panicked or returned the wrong length for non-square
/// matrices.
///
/// # Panics
/// Panics if a stored value cannot be converted to `T`.
fn sum_row_squared<T>(&self) -> anyhow::Result<Vec<T>>
where
    T: Float + NumCast + AddAssign + Sum,
{
    // One accumulator slot per ROW: the triplet's first component is the row index.
    let mut result = vec![T::zero(); self.nrows()];
    for (row, _, &value) in self.triplet_iter() {
        let val = T::from(value).unwrap();
        result[row] += val * val;
    }
    Ok(result)
}
}
impl<M> MatrixVariance for CsrMatrix<M>
where
M: NumericOps + NumCast,
CsrMatrix<M>: MatrixSum + MatrixNonZero,
{
type Item = M;
/// Sample variance (n - 1 denominator) of every column, with `n` equal to the row count.
///
/// Derived from the column sums and sums of squares. Yields zero for every column when
/// `n - 1` is not positive.
fn var_col<I, T>(&self) -> anyhow::Result<Vec<T>>
where
    I: PrimInt + Unsigned + Zero + AddAssign + Into<T> + Send + Sync,
    T: Float + NumCast + AddAssign + Sum + Send + Sync,
{
    let totals: Vec<T> = self.sum_col()?;
    let totals_sq: Vec<T> = self.sum_col_squared()?;
    let width = self.ncols();
    let n = T::from(self.nrows()).unwrap();
    let n_minus_one = n - T::one();
    let stored = self.nnz();
    // Per-column kernel shared by both execution paths.
    let unbiased = |col: usize| -> T {
        let mean = totals[col] / n;
        let population_var = totals_sq[col] / n - mean.powi(2);
        if n_minus_one > T::zero() {
            population_var * (n / n_minus_one)
        } else {
            T::zero()
        }
    };
    if width < 50_000 || stored < 100_000 {
        Ok((0..width).map(unbiased).collect())
    } else {
        Ok((0..width).into_par_iter().map(unbiased).collect())
    }
}
/// Sample variance (n - 1 denominator) of every row, with `n` equal to the column count.
///
/// Each row has `ncols` elements (stored or implicit zero), so the mean and variance are
/// taken over `ncols`. The previous code divided by `nrows`, which is only correct for
/// square matrices. The per-row sums of squares are computed here from the row offsets so
/// the result always has one entry per row.
///
/// Yields zero for every row when `n - 1` is not positive.
///
/// # Panics
/// Panics if a stored value cannot be converted to `T`.
fn var_row<I, T>(&self) -> anyhow::Result<Vec<T>>
where
    I: PrimInt + Unsigned + Zero + AddAssign + Into<T> + Send + Sync,
    T: Float + NumCast + AddAssign + Sum + Send + Sync,
    Self::Item: NumCast,
{
    let sum: Vec<T> = self.sum_row()?;
    let nrows = self.nrows();
    let row_offsets = self.row_offsets();
    let values = self.values();
    // Sum of squared stored values per row (implicit zeros contribute nothing).
    let squared_sums: Vec<T> = (0..nrows)
        .map(|row| {
            values[row_offsets[row]..row_offsets[row + 1]]
                .iter()
                .fold(T::zero(), |acc, &v| {
                    let x = T::from(v).unwrap();
                    acc + x * x
                })
        })
        .collect();
    // Number of observations per row is the number of columns.
    let n = T::from(self.ncols()).unwrap();
    let n_minus_one = n - T::one();
    let unbiased = |row: usize| -> T {
        let mean = sum[row] / n;
        let population_var = squared_sums[row] / n - mean.powi(2);
        if n_minus_one > T::zero() {
            population_var * (n / n_minus_one)
        } else {
            T::zero()
        }
    };
    if nrows < PARALLEL_THRESHOLD {
        Ok((0..nrows).map(unbiased).collect())
    } else {
        Ok((0..nrows).into_par_iter().map(unbiased).collect())
    }
}
/// Writes a per-column variance into `reference`, using each column's stored-entry count
/// as the denominator (columns without stored entries get zero).
///
/// Fails when `reference` does not have exactly one slot per column. Values that cannot be
/// converted to `T` are left out of the sum of squares.
fn var_col_chunk<I, T>(&self, reference: &mut [T]) -> anyhow::Result<()>
where
    I: PrimInt + Unsigned + Zero + AddAssign + Into<T> + Send + Sync,
    T: Float + NumCast + AddAssign + Sum + Send + Sync,
    Self::Item: NumCast,
{
    let width = self.ncols();
    if reference.len() != width {
        return Err(anyhow::anyhow!(
            "Reference slice length {} does not match number of columns {}",
            reference.len(),
            width
        ));
    }
    let totals: Vec<T> = self.sum_col()?;
    let counts: Vec<I> = self.nonzero_col()?;
    let mut totals_sq = vec![T::zero(); width];
    for (&col, raw) in self.col_indices().iter().zip(self.values()) {
        if let Some(x) = T::from(*raw) {
            totals_sq[col] += x * x;
        }
    }
    for (col, slot) in reference.iter_mut().enumerate() {
        *slot = if counts[col] > I::zero() {
            let denom: T = counts[col].into();
            let mean = totals[col] / denom;
            totals_sq[col] / denom - mean * mean
        } else {
            T::zero()
        };
    }
    Ok(())
}
/// Writes a per-row variance into `reference`, computed as the mean squared deviation of
/// the row's stored values from their mean (rows without stored entries get zero).
///
/// Fails when `reference` does not have exactly one slot per row. Values that cannot be
/// converted to `T` are left out of the squared-deviation sum.
fn var_row_chunk<I, T>(&self, reference: &mut [T]) -> anyhow::Result<()>
where
    I: PrimInt + Unsigned + Zero + AddAssign + Into<T> + Send + Sync,
    T: Float + NumCast + AddAssign + Sum + Send + Sync,
    Self::Item: NumCast,
{
    let height = self.nrows();
    if reference.len() != height {
        return Err(anyhow::anyhow!(
            "Reference slice length {} does not match number of rows {}",
            reference.len(),
            height
        ));
    }
    let totals: Vec<T> = self.sum_row()?;
    let counts: Vec<I> = self.nonzero_row()?;
    let offsets = self.row_offsets();
    let data = self.values();
    for (row, slot) in reference.iter_mut().enumerate() {
        let lo = offsets[row];
        // Falls back to the end of the value array if there is no following offset.
        let hi = offsets.get(row + 1).copied().unwrap_or(data.len());
        *slot = if counts[row] > I::zero() {
            let denom: T = counts[row].into();
            let mean = totals[row] / denom;
            let spread = data[lo..hi]
                .iter()
                .filter_map(|&raw| T::from(raw))
                .map(|x| (x - mean) * (x - mean))
                .sum::<T>();
            spread / denom
        } else {
            T::zero()
        };
    }
    Ok(())
}
/// Per-column variance over the rows selected by `mask`, using each column's count of
/// selected stored entries as the denominator (columns with no such entries get zero).
///
/// Fails when `mask` has fewer entries than the matrix has rows. Panics if a selected
/// value cannot be converted to `T`.
fn var_col_masked<I, T>(&self, mask: &[bool]) -> anyhow::Result<Vec<T>>
where
    I: PrimInt + Unsigned + Zero + AddAssign + Into<T> + Send + Sync,
    T: Float + NumCast + AddAssign + Sum + Send + Sync,
{
    let height = self.nrows();
    if mask.len() < height {
        return Err(anyhow::anyhow!(
            "Mask length ({}) is less than number of rows ({})",
            mask.len(),
            height
        ));
    }
    let totals: Vec<T> = self.sum_col_masked(mask)?;
    let counts: Vec<I> = self.nonzero_col_masked(mask)?;
    let width = self.ncols();
    let offsets = self.row_offsets();
    let indices = self.col_indices();
    let data = self.values();
    let mut totals_sq = vec![T::zero(); width];
    for row in (0..height).filter(|&r| mask[r]) {
        for pos in offsets[row]..offsets[row + 1] {
            let x = T::from(data[pos]).unwrap();
            totals_sq[indices[pos]] += x * x;
        }
    }
    let variances: Vec<T> = (0..width)
        .map(|col| {
            if counts[col] > I::zero() {
                let denom: T = counts[col].into();
                let mean = totals[col] / denom;
                totals_sq[col] / denom - mean * mean
            } else {
                T::zero()
            }
        })
        .collect();
    Ok(variances)
}
/// Per-row variance over the columns selected by `mask`: the mean squared deviation of the
/// row's selected stored values from their mean (rows with no such values get zero).
///
/// Fails when `mask` has fewer entries than the matrix has columns. Panics if a selected
/// value cannot be converted to `T`.
fn var_row_masked<I, T>(&self, mask: &[bool]) -> anyhow::Result<Vec<T>>
where
    I: PrimInt + Unsigned + Zero + AddAssign + Into<T> + Send + Sync,
    T: Float + NumCast + AddAssign + Sum + Send + Sync,
{
    let width = self.ncols();
    if mask.len() < width {
        return Err(anyhow::anyhow!(
            "Mask length ({}) is less than number of columns ({})",
            mask.len(),
            width
        ));
    }
    let totals: Vec<T> = self.sum_row_masked(mask)?;
    let counts: Vec<I> = self.nonzero_row_masked(mask)?;
    let offsets = self.row_offsets();
    let indices = self.col_indices();
    let data = self.values();
    let variances: Vec<T> = (0..self.nrows())
        .map(|row| {
            if counts[row] > I::zero() {
                let denom: T = counts[row].into();
                let mean = totals[row] / denom;
                let spread = (offsets[row]..offsets[row + 1])
                    .filter(|&pos| mask[indices[pos]])
                    .fold(T::zero(), |acc, pos| {
                        let diff = T::from(data[pos]).unwrap() - mean;
                        acc + diff * diff
                    });
                spread / denom
            } else {
                T::zero()
            }
        })
        .collect();
    Ok(variances)
}
}
impl<M: NumCast + Copy + PartialOrd + NumericOps> MatrixMinMax for CsrMatrix<M> {
type Item = M;
/// Returns `(minima, maxima)` per column over the stored entries.
///
/// Slots start at `Item::max_value()` / `Item::min_value()` and keep those values for
/// columns without stored entries.
fn min_max_col<Item>(&self) -> anyhow::Result<(Vec<Item>, Vec<Item>)>
where
    Item: NumCast + Copy + PartialOrd + NumericOps + Send + Sync,
{
    let width = self.ncols();
    let mut lows: Vec<Item> = vec![Item::max_value(); width];
    let mut highs: Vec<Item> = vec![Item::min_value(); width];
    self.min_max_col_chunk((lows.as_mut_slice(), highs.as_mut_slice()))?;
    Ok((lows, highs))
}
/// Returns `(minima, maxima)` per row over the stored entries.
///
/// Slots start at `Item::max_value()` / `Item::min_value()` and keep those values for rows
/// without stored entries.
fn min_max_row<Item>(&self) -> anyhow::Result<(Vec<Item>, Vec<Item>)>
where
    Item: NumCast + Copy + PartialOrd + NumericOps + Send + Sync,
{
    let height = self.nrows();
    let mut lows: Vec<Item> = vec![Item::max_value(); height];
    let mut highs: Vec<Item> = vec![Item::min_value(); height];
    self.min_max_row_chunk((lows.as_mut_slice(), highs.as_mut_slice()))?;
    Ok((lows, highs))
}
/// Folds every stored entry into the running per-column `(min, max)` slices in `reference`.
///
/// Existing slot values take part in the comparison, so repeated calls accumulate. Panics
/// if a value cannot be converted to `Item` or a column index is outside the slices.
fn min_max_col_chunk<Item>(&self, reference: (&mut [Item], &mut [Item])) -> anyhow::Result<()>
where
    Item: NumCast + Copy + PartialOrd + NumericOps + Send + Sync,
{
    let (lows, highs) = reference;
    let offsets = self.row_offsets();
    let indices = self.col_indices();
    let data = self.values();
    for bounds in offsets.windows(2).take(self.nrows()) {
        for pos in bounds[0]..bounds[1] {
            let col = indices[pos];
            let candidate = Item::from(data[pos]).unwrap();
            if candidate < lows[col] {
                lows[col] = candidate;
            }
            if candidate > highs[col] {
                highs[col] = candidate;
            }
        }
    }
    Ok(())
}
/// Writes the minimum and maximum stored value of every non-empty row into `reference`.
///
/// Slots of rows without stored entries are left untouched; slots of non-empty rows are
/// overwritten. Panics if a value cannot be converted to `Item` or the slices are shorter
/// than the index of a non-empty row.
fn min_max_row_chunk<Item>(&self, reference: (&mut [Item], &mut [Item])) -> anyhow::Result<()>
where
    Item: NumCast + Copy + PartialOrd + NumericOps,
{
    let (lows, highs) = reference;
    let offsets = self.row_offsets();
    let data = self.values();
    for row in 0..self.nrows() {
        let (lo, hi) = (offsets[row], offsets[row + 1]);
        if lo >= hi {
            continue;
        }
        // Seed with the first entry, then scan the whole row.
        let seed = Item::from(data[lo]).unwrap();
        let (row_low, row_high) =
            data[lo..hi]
                .iter()
                .fold((seed, seed), |(low, high), &raw| {
                    let candidate = Item::from(raw).unwrap();
                    (
                        if candidate < low { candidate } else { low },
                        if candidate > high { candidate } else { high },
                    )
                });
        lows[row] = row_low;
        highs[row] = row_high;
    }
    Ok(())
}
}
impl<T: FloatOpsTS> Normalize<T> for CsrMatrix<T> {
fn normalize<U: FloatOpsTS>(
&mut self,
sums: &[U],
target: U,
direction: &Direction,
) -> anyhow::Result<()> {
let scaling_factors: Vec<U> = sums
.iter()
.map(|&sum| {
if sum > U::zero() {
target / sum
} else {
U::zero()
}
})
.collect();
match direction {
Direction::COLUMN => {
let col_indices = self.col_indices().to_vec();
let values = self.values_mut();
for (val, &col) in values.iter_mut().zip(col_indices.iter()) {
let scale = scaling_factors[col];
if scale > U::zero() {
*val = T::from(U::from(*val).unwrap() * scale).unwrap();
}
}
}
Direction::ROW => {
let row_offsets = self.row_offsets().to_vec();
let nrows = self.nrows();
let values = self.values_mut();
for row in 0..nrows {
let scale = scaling_factors[row];
if scale > U::zero() {
let start = row_offsets[row];
let end = row_offsets[row + 1];
for val in &mut values[start..end] {
*val = T::from(U::from(*val).unwrap() * scale).unwrap();
}
}
}
}
}
Ok(())
}
}
impl<T: FloatOpsTS> Log1P<T> for CsrMatrix<T> {
    /// Replaces every stored value `x` with `ln(1 + x)` in place.
    ///
    /// Uses `ln_1p`, which stays accurate for values close to zero, where forming `1 + x`
    /// first would round away the low-order bits of `x`.
    fn log1p_normalize(&mut self) -> anyhow::Result<()> {
        for val in self.values_mut().iter_mut() {
            *val = val.ln_1p();
        }
        Ok(())
    }
}
impl<M> BatchMatrixVariance for CsrMatrix<M>
where
M: NumericOps + NumCast,
CsrMatrix<M>: MatrixSum + MatrixNonZero,
{
type Item = M;
/// Per-batch, per-column sample variance of the stored values.
///
/// `batches` assigns one label to every row. For each label the result holds one value per
/// column: the variance (count - 1 denominator) of the stored entries of that column within
/// the batch's rows, or zero when fewer than two entries are stored.
///
/// Fails when `batches` does not have exactly one label per row. Panics if a value cannot
/// be converted to `T`.
///
/// NOTE(review): here the labels are per row and the output is per column, whereas
/// `mean_batch_row` in this file expects one label per column — confirm which convention
/// the traits intend.
fn var_batch_row<I, T, B>(&self, batches: &[B]) -> anyhow::Result<HashMap<B, Vec<T>>>
where
    I: PrimInt + Unsigned + Zero + AddAssign + Into<T>,
    T: Float + NumCast + AddAssign + Sum,
    B: BatchIdentifier,
{
    let height = self.nrows();
    if batches.len() != height {
        return Err(anyhow::anyhow!(
            "Batch vector length ({}) doesn't match matrix row count ({})",
            batches.len(),
            height
        ));
    }
    // Group row indices by batch label.
    let mut members: HashMap<B, Vec<usize>> = HashMap::new();
    for (row, label) in batches.iter().enumerate() {
        members.entry(label.clone()).or_default().push(row);
    }
    let width = self.ncols();
    let offsets = self.row_offsets();
    let indices = self.col_indices();
    let data = self.values();
    let mut result: HashMap<B, Vec<T>> = HashMap::new();
    for (label, rows) in members {
        // First pass: per-column sum and stored-entry count, turned into a mean.
        let mut means = vec![T::zero(); width];
        let mut counts = vec![0usize; width];
        for &row in &rows {
            for pos in offsets[row]..offsets[row + 1] {
                let col = indices[pos];
                means[col] = means[col] + T::from(data[pos]).unwrap();
                counts[col] += 1;
            }
        }
        for (mean, &seen) in means.iter_mut().zip(&counts) {
            if seen > 0 {
                *mean = *mean / T::from(seen).unwrap();
            }
        }
        // Second pass: squared deviations from the column mean.
        let mut spread = vec![T::zero(); width];
        for &row in &rows {
            for pos in offsets[row]..offsets[row + 1] {
                let col = indices[pos];
                let diff = T::from(data[pos]).unwrap() - means[col];
                spread[col] = spread[col] + diff * diff;
            }
        }
        let variances: Vec<T> = spread
            .iter()
            .zip(&counts)
            .map(|(&total, &seen)| {
                if seen > 1 {
                    total / T::from(seen - 1).unwrap()
                } else {
                    T::zero()
                }
            })
            .collect();
        result.insert(label, variances);
    }
    Ok(result)
}
/// Per-batch, per-row sample variance of the stored values.
///
/// `batches` assigns one label to every column. For each label the result holds one value
/// per row: the variance (count - 1 denominator) of the row's stored entries that fall in
/// the batch's columns, or zero when fewer than two such entries are stored.
///
/// Batch membership is looked up through a boolean column mask, so the cost per batch is
/// linear in the number of stored entries instead of a `Vec::contains` scan per entry.
///
/// # Errors
/// Fails when `batches` does not have exactly one label per column.
///
/// # Panics
/// Panics if a value cannot be converted to `T`.
fn var_batch_col<I, T, B>(&self, batches: &[B]) -> anyhow::Result<HashMap<B, Vec<T>>>
where
    I: PrimInt + Unsigned + Zero + AddAssign + Into<T>,
    T: Float + NumCast + AddAssign + Sum,
    B: BatchIdentifier,
{
    let n_cols = self.ncols();
    if batches.len() != n_cols {
        return Err(anyhow::anyhow!(
            "Batch vector length ({}) doesn't match matrix column count ({})",
            batches.len(),
            n_cols
        ));
    }
    // Group column indices by batch label.
    let mut batch_columns: HashMap<B, Vec<usize>> = HashMap::new();
    for (col_idx, batch) in batches.iter().enumerate() {
        batch_columns
            .entry(batch.clone())
            .or_default()
            .push(col_idx);
    }
    let n_rows = self.nrows();
    let row_offsets = self.row_offsets();
    let col_indices = self.col_indices();
    let values = self.values();
    // Reused across batches: membership mask and a scratch buffer for one row's values.
    let mut in_batch = vec![false; n_cols];
    let mut row_values: Vec<T> = Vec::new();
    let mut result: HashMap<B, Vec<T>> = HashMap::new();
    for (batch, columns) in batch_columns {
        for &col in &columns {
            in_batch[col] = true;
        }
        let mut batch_vars = vec![T::zero(); n_rows];
        for (row_idx, var) in batch_vars.iter_mut().enumerate() {
            row_values.clear();
            row_values.extend(
                (row_offsets[row_idx]..row_offsets[row_idx + 1])
                    .filter(|&j| in_batch[col_indices[j]])
                    .map(|j| T::from(values[j]).unwrap()),
            );
            if row_values.len() > 1 {
                let mean =
                    row_values.iter().copied().sum::<T>() / T::from(row_values.len()).unwrap();
                let sum_sq_diff = row_values
                    .iter()
                    .map(|&val| {
                        let diff = val - mean;
                        diff * diff
                    })
                    .sum::<T>();
                *var = sum_sq_diff / T::from(row_values.len() - 1).unwrap();
            }
        }
        // Reset the mask for the next batch.
        for &col in &columns {
            in_batch[col] = false;
        }
        result.insert(batch, batch_vars);
    }
    Ok(result)
}
}
impl<M: NumericOps + NumCast> BatchMatrixMean for CsrMatrix<M> {
type Item = M;
/// Per-batch, per-row mean over the columns of each batch.
///
/// `batches` assigns one label to every column. For each label the result holds one value
/// per row: the sum of the row's stored entries in the batch's columns divided by the
/// number of columns in the batch (so implicit zeros count towards the mean).
///
/// The stored entries are walked directly through the CSR arrays rather than via a
/// `get_entry` lookup for every (row, column) pair.
///
/// # Errors
/// Fails when `batches` does not have exactly one label per column.
///
/// # Panics
/// Panics if a value cannot be converted to `T`.
fn mean_batch_row<T, B>(&self, batches: &[B]) -> anyhow::Result<HashMap<B, Vec<T>>>
where
    T: Float + NumCast + AddAssign + std::iter::Sum,
    B: BatchIdentifier,
{
    let n_cols = self.ncols();
    if batches.len() != n_cols {
        return Err(anyhow::anyhow!(
            "Number of batch identifiers ({}) must match number of columns ({})",
            batches.len(),
            n_cols
        ));
    }
    // Group column indices by batch label.
    let mut batch_indices: HashMap<B, Vec<usize>> = HashMap::new();
    for (col_idx, batch) in batches.iter().enumerate() {
        batch_indices
            .entry(batch.clone())
            .or_default()
            .push(col_idx);
    }
    let n_rows = self.nrows();
    let row_offsets = self.row_offsets();
    let col_indices = self.col_indices();
    let values = self.values();
    // Membership mask reused across batches.
    let mut in_batch = vec![false; n_cols];
    let mut result: HashMap<B, Vec<T>> = HashMap::new();
    for (batch, columns) in batch_indices {
        for &col in &columns {
            in_batch[col] = true;
        }
        let col_count = T::from(columns.len()).unwrap();
        let batch_means: Vec<T> = (0..n_rows)
            .map(|row| {
                let mut total = T::zero();
                for j in row_offsets[row]..row_offsets[row + 1] {
                    if in_batch[col_indices[j]] {
                        total += T::from(values[j]).unwrap();
                    }
                }
                total / col_count
            })
            .collect();
        // Reset the mask for the next batch.
        for &col in &columns {
            in_batch[col] = false;
        }
        result.insert(batch, batch_means);
    }
    Ok(result)
}
/// Per-batch, per-column mean over the rows of each batch.
///
/// `batches` assigns one label to every row. For each label the result holds one value per
/// column: the sum of the column's stored entries in the batch's rows divided by the
/// number of rows in the batch (so implicit zeros count towards the mean).
///
/// The previous code destructured each triplet as `(col_idx, _, value)` although triplets
/// are `(row, col, value)`, so values were accumulated at the row index instead of the
/// column index. It also rescanned every triplet for every row. Rows are now read directly
/// through the row offsets.
///
/// # Errors
/// Fails when `batches` does not have exactly one label per row.
///
/// # Panics
/// Panics if a value cannot be converted to `T`.
fn mean_batch_col<T, B>(&self, batches: &[B]) -> anyhow::Result<HashMap<B, Vec<T>>>
where
    T: Float + NumCast + AddAssign + std::iter::Sum,
    B: BatchIdentifier,
{
    let n_rows = self.nrows();
    if batches.len() != n_rows {
        return Err(anyhow::anyhow!(
            "Number of batch identifiers ({}) must match number of rows ({})",
            batches.len(),
            n_rows
        ));
    }
    // Group row indices by batch label.
    let mut batch_indices: HashMap<B, Vec<usize>> = HashMap::new();
    for (row_idx, batch) in batches.iter().enumerate() {
        batch_indices
            .entry(batch.clone())
            .or_default()
            .push(row_idx);
    }
    let n_cols = self.ncols();
    let row_offsets = self.row_offsets();
    let col_indices = self.col_indices();
    let values = self.values();
    let mut result: HashMap<B, Vec<T>> = HashMap::new();
    for (batch, row_indices) in batch_indices {
        let mut batch_means = vec![T::zero(); n_cols];
        for &row_idx in &row_indices {
            for j in row_offsets[row_idx]..row_offsets[row_idx + 1] {
                batch_means[col_indices[j]] += T::from(values[j]).unwrap();
            }
        }
        let row_count = T::from(row_indices.len()).unwrap();
        for mean in &mut batch_means {
            *mean = *mean / row_count;
        }
        result.insert(batch, batch_means);
    }
    Ok(result)
}
}
impl<M: NumericOps + NumCast> MatrixNTop for CsrMatrix<M> {
    type Item = M;

    /// For every row, sums the `n` largest stored values (all of them when the row stores
    /// at most `n`).
    ///
    /// Values that cannot be converted to `T` are dropped before selection. Incomparable
    /// pairs are treated as equal while sorting.
    fn sum_row_n_top<T>(&self, n: usize) -> anyhow::Result<Vec<T>>
    where
        T: Float + NumCast + AddAssign + Sum,
    {
        let offsets = self.row_offsets();
        let data = self.values();
        let totals: Vec<T> = (0..self.nrows())
            .map(|row| {
                let mut entries: Vec<T> = data[offsets[row]..offsets[row + 1]]
                    .iter()
                    .filter_map(|&raw| T::from(raw))
                    .collect();
                if entries.len() > n {
                    // Descending order, then keep the leading `n`.
                    entries
                        .sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
                    entries.truncate(n);
                }
                entries.into_iter().sum()
            })
            .collect();
        Ok(totals)
    }
}
#[cfg(test)]
mod tests {
    // These tests exercise the `CsrMatrix` implementations defined in this file. They used
    // to build `CscMatrix` values, which left the CSR code paths untested.
    use super::*;
    use nalgebra_sparse::CooMatrix;

    /// 4x3 fixture with two stored entries in every column and none in row 1.
    fn create_test_matrix() -> CsrMatrix<f64> {
        let mut coo = CooMatrix::new(4, 3);
        coo.push(0, 0, 1.0);
        coo.push(2, 0, 3.0);
        coo.push(2, 1, 4.0);
        coo.push(3, 1, 5.0);
        coo.push(0, 2, 2.0);
        coo.push(3, 2, 6.0);
        CsrMatrix::from(&coo)
    }

    #[test]
    fn test_nonzero_col() {
        let matrix = create_test_matrix();
        let result: Vec<u32> = matrix.nonzero_col().unwrap();
        assert_eq!(result, vec![2, 2, 2]);
    }

    #[test]
    fn test_nonzero_row() {
        let matrix = create_test_matrix();
        let result: Vec<u32> = matrix.nonzero_row().unwrap();
        assert_eq!(result, vec![2, 0, 2, 2]);
    }

    #[test]
    fn test_nonzero_col_chunk() {
        let matrix = create_test_matrix();
        let mut reference = vec![0u32; 4];
        matrix.nonzero_col_chunk(&mut reference).unwrap();
        assert_eq!(reference, vec![2, 2, 2, 0]);
    }

    #[test]
    fn test_nonzero_row_chunk() {
        let matrix = create_test_matrix();
        let mut reference = vec![0u32; 3];
        matrix.nonzero_row_chunk(&mut reference).unwrap();
        assert_eq!(reference, vec![2, 0, 2]);
    }

    #[test]
    fn test_empty_matrix() {
        let matrix: CsrMatrix<f64> = CsrMatrix::zeros(0, 0);
        assert!(matrix.nonzero_col::<u32>().unwrap().is_empty());
        assert!(matrix.nonzero_row::<u32>().unwrap().is_empty());
        let mut empty_ref: Vec<u32> = Vec::new();
        assert!(matrix.nonzero_col_chunk(&mut empty_ref).is_ok());
        assert!(matrix.nonzero_row_chunk(&mut empty_ref).is_ok());
    }

    #[test]
    fn test_different_integer_types() {
        let matrix = create_test_matrix();
        let result_u8: Vec<u8> = matrix.nonzero_col().unwrap();
        assert_eq!(result_u8, vec![2, 2, 2]);
        let result_u64: Vec<u64> = matrix.nonzero_col().unwrap();
        assert_eq!(result_u64, vec![2, 2, 2]);
    }

    #[test]
    fn test_large_sparse_matrix() {
        // Bidiagonal pattern: column i holds entries in rows i and i + 1.
        let mut coo = CooMatrix::new(1000, 1000);
        for i in 0..999 {
            coo.push(i, i, 1.0);
            coo.push(i + 1, i, 1.0);
        }
        let matrix = CsrMatrix::from(&coo);
        let result: Vec<u32> = matrix.nonzero_col().unwrap();
        assert_eq!(result.len(), 1000);
        assert_eq!(result[500], 2);
    }

    #[test]
    fn test_chunk_smaller_than_matrix() {
        let matrix = create_test_matrix();
        let mut col_ref = vec![0u32; 2];
        matrix.nonzero_col_chunk(&mut col_ref).unwrap();
        assert_eq!(col_ref, vec![2, 2]);
        let mut row_ref = vec![0u32; 2];
        matrix.nonzero_row_chunk(&mut row_ref).unwrap();
        assert_eq!(row_ref, vec![2, 0]);
    }

    #[test]
    fn test_zero_matrix() {
        let matrix: CsrMatrix<f64> = CsrMatrix::zeros(5, 4);
        let col_result: Vec<u32> = matrix.nonzero_col().unwrap();
        assert_eq!(col_result, vec![0, 0, 0, 0]);
        let row_result: Vec<u32> = matrix.nonzero_row().unwrap();
        assert_eq!(row_result, vec![0, 0, 0, 0, 0]);
    }

    #[test]
    fn test_csr_normalize() {
        let coo = CooMatrix::try_from_triplets(
            3,
            3,
            vec![0, 0, 1, 1, 2],
            vec![0, 1, 1, 2, 2],
            vec![2.0, 3.0, 4.0, 1.0, 2.0],
        )
        .unwrap();

        // Column-wise: every column is scaled to sum to `target`.
        let mut csr: CsrMatrix<f64> = (&coo).into();
        let col_sums = vec![2.0, 7.0, 3.0];
        let target = 1.0;
        csr.normalize(&col_sums, target, &Direction::COLUMN)
            .unwrap();
        let expected_values = [1.0, 3.0 / 7.0, 4.0 / 7.0, 1.0 / 3.0, 2.0 / 3.0];
        for ((_, _, val), expected) in csr.triplet_iter().zip(expected_values.iter()) {
            assert!((val - expected).abs() < 1e-10);
        }

        // Row-wise: every row is scaled to sum to `target`.
        let mut csr: CsrMatrix<f64> = (&coo).into();
        let row_sums = vec![5.0, 5.0, 2.0];
        csr.normalize(&row_sums, target, &Direction::ROW).unwrap();
        let expected_values = [0.4, 0.6, 0.8, 0.2, 1.0];
        for ((_, _, val), expected) in csr.triplet_iter().zip(expected_values.iter()) {
            assert!((val - expected).abs() < 1e-10);
        }
    }
}