use scirs2_core::ndarray::{Array, ArrayD, IxDyn};
use std::collections::HashMap;
use thiserror::Error;
/// Errors produced by sparse tensor construction and operations.
#[derive(Debug, Error)]
pub enum SparseError {
    /// Two tensors (or a tensor and a vector) had incompatible shapes.
    #[error("Shape mismatch: expected {expected:?}, got {got:?}")]
    ShapeMismatch { expected: Vec<usize>, got: Vec<usize> },
    /// An index had the wrong arity or a component outside the tensor shape.
    #[error("Index out of bounds: index {index:?} for shape {shape:?}")]
    IndexOutOfBounds { index: Vec<usize>, shape: Vec<usize> },
    /// The requested storage format cannot represent the tensor
    /// (e.g. CSR for a non-2D tensor).
    #[error("Invalid format: {0}")]
    InvalidFormat(String),
    /// The operation is not supported for this tensor/format.
    #[error("Unsupported operation: {0}")]
    UnsupportedOperation(String),
    /// The tensor contains no elements.
    #[error("Empty tensor")]
    EmptyTensor,
}
/// Convenience alias used by all fallible sparse operations.
pub type SparseResult<T> = Result<T, SparseError>;
/// Storage layout used for a sparse tensor.
///
/// `COO` (coordinate list) is the general N-dimensional format and the
/// default; `CSR`/`CSC` are compressed 2-D row/column formats.
///
/// The manual `impl Default` was replaced by `#[derive(Default)]` with the
/// `#[default]` variant attribute (clippy::derivable_impls); behavior is
/// identical: `SparseFormat::default() == SparseFormat::COO`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SparseFormat {
    /// Coordinate list: parallel index/value vectors (any rank).
    #[default]
    COO,
    /// Compressed sparse row (2-D only).
    CSR,
    /// Compressed sparse column (2-D only).
    CSC,
}
/// Sparse tensor in COO (coordinate) format: parallel lists of
/// N-dimensional indices and their explicitly stored values.
#[derive(Debug, Clone)]
pub struct SparseCOO {
    /// Extent of each dimension.
    pub shape: Vec<usize>,
    /// One index vector (length == shape.len()) per stored value.
    pub indices: Vec<Vec<usize>>,
    /// Stored values, parallel to `indices`; duplicate coordinates may be
    /// present until `sum_duplicates` is called.
    pub values: Vec<f64>,
    /// True when `indices` is known to be in lexicographic order.
    sorted: bool,
}
impl SparseCOO {
    /// Creates an empty tensor of the given shape; every element is an
    /// implicit zero.
    pub fn new(shape: Vec<usize>) -> Self {
        SparseCOO {
            shape,
            indices: Vec::new(),
            values: Vec::new(),
            // An empty entry list is trivially sorted.
            sorted: true,
        }
    }

    /// Builds a 2-D tensor from `(row, col, value)` triplets.
    ///
    /// Values with magnitude <= 1e-15 are dropped. Duplicate coordinates
    /// are kept as separate entries (merge them with `sum_duplicates`).
    /// Indices are not validated against `shape`.
    pub fn from_triplets(shape: Vec<usize>, triplets: Vec<(usize, usize, f64)>) -> Self {
        let mut indices = Vec::with_capacity(triplets.len());
        let mut values = Vec::with_capacity(triplets.len());
        for (row, col, val) in triplets {
            if val.abs() > 1e-15 {
                indices.push(vec![row, col]);
                values.push(val);
            }
        }
        SparseCOO {
            shape,
            indices,
            values,
            sorted: false,
        }
    }

    /// Converts a dense N-D array, keeping only entries with magnitude
    /// above 1e-15.
    pub fn from_dense(dense: &ArrayD<f64>) -> Self {
        let shape = dense.shape().to_vec();
        let mut indices = Vec::new();
        let mut values = Vec::new();
        for (idx, &val) in dense.indexed_iter() {
            if val.abs() > 1e-15 {
                indices.push(idx.as_array_view().to_vec());
                values.push(val);
            }
        }
        SparseCOO {
            shape,
            indices,
            values,
            // Iteration order depends on the array's memory layout, so we
            // conservatively do not assume the indices are sorted.
            sorted: false,
        }
    }

    /// Appends one entry after validating index arity and bounds.
    ///
    /// Near-zero values are silently ignored. Adding the same coordinate
    /// twice keeps both entries; `get`, `to_dense`, and `sum_duplicates`
    /// all treat duplicates as summed.
    ///
    /// # Errors
    /// `IndexOutOfBounds` when `indices.len()` differs from the rank or any
    /// component is >= the corresponding dimension.
    pub fn add(&mut self, indices: Vec<usize>, value: f64) -> SparseResult<()> {
        if indices.len() != self.shape.len() {
            return Err(SparseError::IndexOutOfBounds {
                index: indices,
                shape: self.shape.clone(),
            });
        }
        for (i, &idx) in indices.iter().enumerate() {
            if idx >= self.shape[i] {
                return Err(SparseError::IndexOutOfBounds {
                    index: indices,
                    shape: self.shape.clone(),
                });
            }
        }
        if value.abs() > 1e-15 {
            self.indices.push(indices);
            self.values.push(value);
            self.sorted = false;
        }
        Ok(())
    }

    /// Returns the value at `target`, or 0.0 when absent.
    ///
    /// Duplicate entries at the same coordinate are summed, keeping `get`
    /// consistent with `to_dense` and `sum_duplicates` (previously only the
    /// first matching entry was returned, while `to_dense` let the last one
    /// win).
    pub fn get(&self, target: &[usize]) -> f64 {
        let mut total = 0.0;
        for (indices, &value) in self.indices.iter().zip(self.values.iter()) {
            if indices.as_slice() == target {
                total += value;
            }
        }
        total
    }

    /// Number of explicitly stored entries (including any duplicates).
    pub fn nnz(&self) -> usize {
        self.values.len()
    }

    /// Total number of elements, stored or implicit.
    pub fn numel(&self) -> usize {
        self.shape.iter().product()
    }

    /// Fraction of elements that are implicit zeros; a zero-element shape
    /// is reported as fully sparse (1.0).
    pub fn sparsity(&self) -> f64 {
        let total = self.numel();
        if total == 0 {
            return 1.0;
        }
        1.0 - (self.nnz() as f64 / total as f64)
    }

    /// Materializes the tensor as a dense array.
    ///
    /// Duplicate coordinates are accumulated (previously the last duplicate
    /// silently overwrote the earlier ones), matching `get`'s semantics.
    pub fn to_dense(&self) -> ArrayD<f64> {
        let mut dense = ArrayD::zeros(IxDyn(&self.shape));
        for (indices, &value) in self.indices.iter().zip(self.values.iter()) {
            dense[IxDyn(indices)] += value;
        }
        dense
    }

    /// Sorts entries lexicographically by index; no-op when already sorted.
    pub fn sort(&mut self) {
        if self.sorted {
            return;
        }
        // Sort a permutation instead of the parallel vectors themselves so
        // indices and values stay aligned.
        let mut perm: Vec<usize> = (0..self.nnz()).collect();
        perm.sort_by(|&a, &b| self.indices[a].cmp(&self.indices[b]));
        let new_indices: Vec<Vec<usize>> = perm.iter().map(|&i| self.indices[i].clone()).collect();
        let new_values: Vec<f64> = perm.iter().map(|&i| self.values[i]).collect();
        self.indices = new_indices;
        self.values = new_values;
        self.sorted = true;
    }

    /// Merges duplicate coordinates by summing their values, dropping any
    /// sums that land near zero. Leaves the entries sorted.
    pub fn sum_duplicates(&mut self) {
        if self.indices.is_empty() {
            return;
        }
        self.sort();
        let mut new_indices = Vec::new();
        let mut new_values = Vec::new();
        // Walk the sorted entries, accumulating each run of equal indices.
        let mut current_idx = self.indices[0].clone();
        let mut current_val = self.values[0];
        for i in 1..self.nnz() {
            if self.indices[i] == current_idx {
                current_val += self.values[i];
            } else {
                if current_val.abs() > 1e-15 {
                    new_indices.push(current_idx);
                    new_values.push(current_val);
                }
                current_idx = self.indices[i].clone();
                current_val = self.values[i];
            }
        }
        // Flush the final run.
        if current_val.abs() > 1e-15 {
            new_indices.push(current_idx);
            new_values.push(current_val);
        }
        self.indices = new_indices;
        self.values = new_values;
    }
}
/// Sparse matrix in CSR (compressed sparse row) format; 2-D only.
#[derive(Debug, Clone)]
pub struct SparseCSR {
    /// (rows, cols) extents.
    pub shape: (usize, usize),
    /// `row_ptr[r]..row_ptr[r + 1]` delimits row r's segment of
    /// `col_idx`/`values`; length is rows + 1.
    pub row_ptr: Vec<usize>,
    /// Column index of each stored value, sorted within each row.
    pub col_idx: Vec<usize>,
    /// Stored values, parallel to `col_idx`.
    pub values: Vec<f64>,
}
impl SparseCSR {
    /// Creates an empty matrix of the given (rows, cols) shape.
    pub fn new(shape: (usize, usize)) -> Self {
        SparseCSR {
            shape,
            row_ptr: vec![0; shape.0 + 1],
            col_idx: Vec::new(),
            values: Vec::new(),
        }
    }

    /// Converts a 2-D COO tensor to CSR.
    ///
    /// Column indices within each row end up sorted (stably, so duplicate
    /// coordinates keep their relative order). Duplicates are not merged.
    ///
    /// # Errors
    /// `InvalidFormat` when the COO tensor is not 2-D, and
    /// `IndexOutOfBounds` when an entry lies outside `coo.shape`
    /// (previously an out-of-range row caused a panic).
    pub fn from_coo(coo: &SparseCOO) -> SparseResult<Self> {
        if coo.shape.len() != 2 {
            return Err(SparseError::InvalidFormat(
                "CSR only supports 2D tensors".to_string(),
            ));
        }
        let nrows = coo.shape[0];
        let ncols = coo.shape[1];
        // Count entries per row, validating bounds as we go.
        let mut row_counts = vec![0usize; nrows];
        for indices in &coo.indices {
            if indices[0] >= nrows || indices[1] >= ncols {
                return Err(SparseError::IndexOutOfBounds {
                    index: indices.clone(),
                    shape: coo.shape.clone(),
                });
            }
            row_counts[indices[0]] += 1;
        }
        // Prefix-sum the counts into the row pointer array.
        let mut row_ptr = vec![0usize; nrows + 1];
        for i in 0..nrows {
            row_ptr[i + 1] = row_ptr[i] + row_counts[i];
        }
        // Scatter column/value pairs into their per-row segments.
        let nnz = coo.nnz();
        let mut col_idx = vec![0usize; nnz];
        let mut values = vec![0.0f64; nnz];
        let mut row_offsets = row_ptr.clone();
        for (indices, &value) in coo.indices.iter().zip(coo.values.iter()) {
            let row = indices[0];
            let pos = row_offsets[row];
            col_idx[pos] = indices[1];
            values[pos] = value;
            row_offsets[row] += 1;
        }
        // Sort each row segment by column so `get` can binary-search.
        // A stable O(k log k) sort replaces the previous O(k^2) per-row
        // insertion sort; stability preserves duplicate ordering.
        for row in 0..nrows {
            let start = row_ptr[row];
            let end = row_ptr[row + 1];
            let mut pairs: Vec<(usize, f64)> = col_idx[start..end]
                .iter()
                .copied()
                .zip(values[start..end].iter().copied())
                .collect();
            pairs.sort_by_key(|&(col, _)| col);
            for (offset, (col, val)) in pairs.into_iter().enumerate() {
                col_idx[start + offset] = col;
                values[start + offset] = val;
            }
        }
        Ok(SparseCSR {
            shape: (nrows, ncols),
            row_ptr,
            col_idx,
            values,
        })
    }

    /// Converts a dense 2-D array via an intermediate COO tensor.
    ///
    /// # Errors
    /// `InvalidFormat` when `dense` is not 2-D.
    pub fn from_dense(dense: &ArrayD<f64>) -> SparseResult<Self> {
        if dense.ndim() != 2 {
            return Err(SparseError::InvalidFormat(
                "CSR only supports 2D tensors".to_string(),
            ));
        }
        let coo = SparseCOO::from_dense(dense);
        Self::from_coo(&coo)
    }

    /// Number of stored entries.
    pub fn nnz(&self) -> usize {
        self.values.len()
    }

    /// Column indices and values for one row, or `None` when `row` is out
    /// of range.
    pub fn get_row(&self, row: usize) -> Option<(&[usize], &[f64])> {
        if row >= self.shape.0 {
            return None;
        }
        let start = self.row_ptr[row];
        let end = self.row_ptr[row + 1];
        Some((&self.col_idx[start..end], &self.values[start..end]))
    }

    /// Value at (row, col); out-of-range coordinates and unstored entries
    /// read as 0.0. Binary-searches the row's sorted column segment.
    pub fn get(&self, row: usize, col: usize) -> f64 {
        if row >= self.shape.0 || col >= self.shape.1 {
            return 0.0;
        }
        let start = self.row_ptr[row];
        let end = self.row_ptr[row + 1];
        match self.col_idx[start..end].binary_search(&col) {
            Ok(pos) => self.values[start + pos],
            Err(_) => 0.0,
        }
    }

    /// Materializes the matrix as a dense 2-D array.
    pub fn to_dense(&self) -> ArrayD<f64> {
        let mut dense = ArrayD::zeros(IxDyn(&[self.shape.0, self.shape.1]));
        for row in 0..self.shape.0 {
            let start = self.row_ptr[row];
            let end = self.row_ptr[row + 1];
            for (offset, &col) in self.col_idx[start..end].iter().enumerate() {
                dense[[row, col]] = self.values[start + offset];
            }
        }
        dense
    }

    /// Converts back to COO; entries are emitted in (row, col) order, so
    /// the result is marked sorted.
    pub fn to_coo(&self) -> SparseCOO {
        let mut indices = Vec::with_capacity(self.nnz());
        let mut values = Vec::with_capacity(self.nnz());
        for row in 0..self.shape.0 {
            let start = self.row_ptr[row];
            let end = self.row_ptr[row + 1];
            for (offset, &col) in self.col_idx[start..end].iter().enumerate() {
                indices.push(vec![row, col]);
                values.push(self.values[start + offset]);
            }
        }
        SparseCOO {
            shape: vec![self.shape.0, self.shape.1],
            indices,
            values,
            sorted: true,
        }
    }

    /// Sparse matrix-vector product `y = A * x`.
    ///
    /// # Errors
    /// `ShapeMismatch` when `x.len()` differs from the column count.
    pub fn matvec(&self, x: &[f64]) -> SparseResult<Vec<f64>> {
        if x.len() != self.shape.1 {
            return Err(SparseError::ShapeMismatch {
                expected: vec![self.shape.1],
                got: vec![x.len()],
            });
        }
        let mut y = vec![0.0; self.shape.0];
        for row in 0..self.shape.0 {
            let start = self.row_ptr[row];
            let end = self.row_ptr[row + 1];
            for (offset, &col) in self.col_idx[start..end].iter().enumerate() {
                y[row] += self.values[start + offset] * x[col];
            }
        }
        Ok(y)
    }
}
/// High-level sparse tensor: canonical COO storage plus a lazily built CSR
/// cache (populated by `as_csr` for 2-D tensors).
#[derive(Debug, Clone)]
pub struct SparseTensor {
    /// Format the tensor was created from. NOTE(review): this field is
    /// written by the constructors but never read in this file — confirm
    /// whether external callers rely on it before removing.
    format: SparseFormat,
    /// Canonical coordinate-list representation (always present).
    coo: SparseCOO,
    /// Cached CSR representation, if one has been built.
    csr: Option<SparseCSR>,
}
impl SparseTensor {
    /// Wraps an existing COO tensor (no CSR cache yet).
    pub fn from_coo(coo: SparseCOO) -> Self {
        SparseTensor {
            format: SparseFormat::COO,
            coo,
            csr: None,
        }
    }

    /// Wraps a CSR matrix, keeping both the CSR form and a COO mirror.
    pub fn from_csr(csr: SparseCSR) -> Self {
        let coo = csr.to_coo();
        SparseTensor {
            format: SparseFormat::CSR,
            coo,
            csr: Some(csr),
        }
    }

    /// Converts a dense array, dropping near-zero entries.
    pub fn from_dense(dense: &ArrayD<f64>) -> Self {
        SparseTensor {
            format: SparseFormat::COO,
            coo: SparseCOO::from_dense(dense),
            csr: None,
        }
    }

    /// All-zero tensor of the given shape.
    pub fn zeros(shape: Vec<usize>) -> Self {
        SparseTensor {
            format: SparseFormat::COO,
            coo: SparseCOO::new(shape),
            csr: None,
        }
    }

    /// n-by-n identity matrix.
    pub fn eye(n: usize) -> Self {
        let triplets: Vec<_> = (0..n).map(|i| (i, i, 1.0)).collect();
        SparseTensor::from_coo(SparseCOO::from_triplets(vec![n, n], triplets))
    }

    /// Extent of each dimension.
    pub fn shape(&self) -> &[usize] {
        &self.coo.shape
    }

    /// Number of dimensions.
    pub fn ndim(&self) -> usize {
        self.coo.shape.len()
    }

    /// Number of explicitly stored entries.
    pub fn nnz(&self) -> usize {
        self.coo.nnz()
    }

    /// Total number of elements, stored or implicit.
    pub fn numel(&self) -> usize {
        self.coo.numel()
    }

    /// Fraction of elements that are implicit zeros.
    pub fn sparsity(&self) -> f64 {
        self.coo.sparsity()
    }

    /// Fraction of memory saved relative to a dense f64 array; negative
    /// when the tensor is too dense for COO storage to pay off.
    pub fn memory_savings(&self) -> f64 {
        let dense_bytes = self.numel() * 8; // 8 bytes per f64 element
        // Each stored entry costs one f64 value plus one usize per dimension.
        let sparse_bytes = self.nnz() * (8 + self.ndim() * 8);
        if dense_bytes == 0 {
            return 0.0;
        }
        1.0 - (sparse_bytes as f64 / dense_bytes as f64)
    }

    /// Materializes the tensor as a dense array.
    pub fn to_dense(&self) -> ArrayD<f64> {
        self.coo.to_dense()
    }

    /// Borrows the underlying COO representation.
    pub fn as_coo(&self) -> &SparseCOO {
        &self.coo
    }

    /// Returns the CSR representation, building and caching it on first
    /// use.
    ///
    /// # Errors
    /// Propagates `SparseError::InvalidFormat` for non-2-D tensors.
    pub fn as_csr(&mut self) -> SparseResult<&SparseCSR> {
        if self.csr.is_none() {
            self.csr = Some(SparseCSR::from_coo(&self.coo)?);
        }
        Ok(self.csr.as_ref().expect("csr was just set above when None"))
    }

    /// Value at `indices`, or 0.0 when not stored.
    pub fn get(&self, indices: &[usize]) -> f64 {
        self.coo.get(indices)
    }

    /// Applies `f` to every *stored* value, dropping near-zero results.
    ///
    /// NOTE: implicit zeros are not passed through `f`, so this is only
    /// correct for functions with f(0) == 0 (sqrt, relu, scaling, ...).
    pub fn map<F>(&self, f: F) -> Self
    where
        F: Fn(f64) -> f64,
    {
        let mut new_indices = Vec::new();
        let mut new_values = Vec::new();
        for (indices, &value) in self.coo.indices.iter().zip(self.coo.values.iter()) {
            let mapped = f(value);
            if mapped.abs() > 1e-15 {
                new_indices.push(indices.clone());
                new_values.push(mapped);
            }
        }
        SparseTensor::from_coo(SparseCOO {
            shape: self.coo.shape.clone(),
            indices: new_indices,
            values: new_values,
            // Filtering preserves the original entry order.
            sorted: self.coo.sorted,
        })
    }

    /// Multiplies every stored value by `scalar`; scaling by (near) zero
    /// short-circuits to an empty tensor.
    pub fn scale(&self, scalar: f64) -> Self {
        if scalar.abs() < 1e-15 {
            return SparseTensor::zeros(self.coo.shape.clone());
        }
        let new_values: Vec<f64> = self.coo.values.iter().map(|&v| v * scalar).collect();
        SparseTensor::from_coo(SparseCOO {
            shape: self.coo.shape.clone(),
            indices: self.coo.indices.clone(),
            values: new_values,
            sorted: self.coo.sorted,
        })
    }

    /// Element-wise sum of two same-shape tensors. Duplicate coordinates
    /// within either operand are summed as part of the accumulation, and
    /// entries that cancel to (near) zero are dropped.
    ///
    /// # Errors
    /// `ShapeMismatch` when the shapes differ.
    pub fn add(&self, other: &SparseTensor) -> SparseResult<SparseTensor> {
        if self.coo.shape != other.coo.shape {
            return Err(SparseError::ShapeMismatch {
                expected: self.coo.shape.clone(),
                got: other.coo.shape.clone(),
            });
        }
        // Accumulate both operands into one coordinate -> value map.
        let mut values_map: HashMap<Vec<usize>, f64> = HashMap::new();
        for (indices, &value) in self.coo.indices.iter().zip(self.coo.values.iter()) {
            *values_map.entry(indices.clone()).or_insert(0.0) += value;
        }
        for (indices, &value) in other.coo.indices.iter().zip(other.coo.values.iter()) {
            *values_map.entry(indices.clone()).or_insert(0.0) += value;
        }
        let mut new_indices = Vec::new();
        let mut new_values = Vec::new();
        for (indices, value) in values_map {
            if value.abs() > 1e-15 {
                new_indices.push(indices);
                new_values.push(value);
            }
        }
        let mut coo = SparseCOO {
            shape: self.coo.shape.clone(),
            indices: new_indices,
            values: new_values,
            sorted: false,
        };
        coo.sort();
        Ok(SparseTensor::from_coo(coo))
    }

    /// Element-wise (Hadamard) product; only coordinates stored in *both*
    /// operands can be nonzero, so everything else is skipped.
    ///
    /// NOTE(review): duplicate coordinates in `other` overwrite each other
    /// in the lookup map — call `sum_duplicates` first for exact results.
    ///
    /// # Errors
    /// `ShapeMismatch` when the shapes differ.
    pub fn hadamard(&self, other: &SparseTensor) -> SparseResult<SparseTensor> {
        if self.coo.shape != other.coo.shape {
            return Err(SparseError::ShapeMismatch {
                expected: self.coo.shape.clone(),
                got: other.coo.shape.clone(),
            });
        }
        let other_map: HashMap<&Vec<usize>, f64> = other
            .coo
            .indices
            .iter()
            .zip(other.coo.values.iter())
            .map(|(idx, &val)| (idx, val))
            .collect();
        let mut new_indices = Vec::new();
        let mut new_values = Vec::new();
        for (indices, &value) in self.coo.indices.iter().zip(self.coo.values.iter()) {
            if let Some(&other_value) = other_map.get(indices) {
                let product = value * other_value;
                if product.abs() > 1e-15 {
                    new_indices.push(indices.clone());
                    new_values.push(product);
                }
            }
        }
        let coo = SparseCOO {
            shape: self.coo.shape.clone(),
            indices: new_indices,
            values: new_values,
            sorted: self.coo.sorted,
        };
        Ok(SparseTensor::from_coo(coo))
    }

    /// Element-wise maximum of two same-shape tensors.
    ///
    /// Coordinates stored in only one operand are compared against the
    /// other operand's *implicit zero*, so e.g. max(-3.0, unstored) is 0.0
    /// (which is then dropped as a zero). Previously implicit zeros were
    /// ignored, wrongly keeping negative values. Duplicate coordinates
    /// within one operand are merged with `max` before comparison.
    ///
    /// # Errors
    /// `ShapeMismatch` when the shapes differ.
    pub fn maximum(&self, other: &SparseTensor) -> SparseResult<SparseTensor> {
        if self.coo.shape != other.coo.shape {
            return Err(SparseError::ShapeMismatch {
                expected: self.coo.shape.clone(),
                got: other.coo.shape.clone(),
            });
        }
        // Collapse each operand to one value per coordinate so we can tell
        // which coordinates are present on both sides.
        let collapse = |coo: &SparseCOO| -> HashMap<Vec<usize>, f64> {
            let mut m: HashMap<Vec<usize>, f64> = HashMap::new();
            for (indices, &value) in coo.indices.iter().zip(coo.values.iter()) {
                m.entry(indices.clone())
                    .and_modify(|v| *v = v.max(value))
                    .or_insert(value);
            }
            m
        };
        let a_map = collapse(&self.coo);
        let b_map = collapse(&other.coo);
        let mut new_indices = Vec::new();
        let mut new_values = Vec::new();
        for (indices, &a) in &a_map {
            let value = match b_map.get(indices) {
                Some(&b) => a.max(b),
                // Unstored in `other` means its value there is zero.
                None => a.max(0.0),
            };
            if value.abs() > 1e-15 {
                new_indices.push(indices.clone());
                new_values.push(value);
            }
        }
        for (indices, &b) in &b_map {
            if !a_map.contains_key(indices) {
                // Unstored in `self` means its value here is zero.
                let value = b.max(0.0);
                if value.abs() > 1e-15 {
                    new_indices.push(indices.clone());
                    new_values.push(value);
                }
            }
        }
        let mut coo = SparseCOO {
            shape: self.coo.shape.clone(),
            indices: new_indices,
            values: new_values,
            sorted: false,
        };
        coo.sort();
        Ok(SparseTensor::from_coo(coo))
    }

    /// Sums over the given axes, returning a tensor over the remaining
    /// dimensions (a 0-D scalar tensor when every axis is reduced). An
    /// empty `axes` slice returns a clone.
    ///
    /// # Errors
    /// `IndexOutOfBounds` when an axis is >= `self.ndim()`.
    pub fn sum(&self, axes: &[usize]) -> SparseResult<SparseTensor> {
        if axes.is_empty() {
            return Ok(self.clone());
        }
        for &axis in axes {
            if axis >= self.ndim() {
                return Err(SparseError::IndexOutOfBounds {
                    index: vec![axis],
                    shape: self.coo.shape.clone(),
                });
            }
        }
        let kept_dims: Vec<usize> = (0..self.ndim()).filter(|d| !axes.contains(d)).collect();
        let new_shape: Vec<usize> = kept_dims.iter().map(|&d| self.coo.shape[d]).collect();
        // Reducing every axis yields a 0-D scalar tensor. Implicit zeros
        // contribute nothing to a sum, so only stored values matter.
        if new_shape.is_empty() {
            let total: f64 = self.coo.values.iter().sum();
            let mut coo = SparseCOO::new(vec![]);
            if total.abs() > 1e-15 {
                coo.indices.push(vec![]);
                coo.values.push(total);
            }
            return Ok(SparseTensor::from_coo(coo));
        }
        // Project each entry onto the kept dimensions and accumulate.
        let mut values_map: HashMap<Vec<usize>, f64> = HashMap::new();
        for (indices, &value) in self.coo.indices.iter().zip(self.coo.values.iter()) {
            let key: Vec<usize> = kept_dims.iter().map(|&d| indices[d]).collect();
            *values_map.entry(key).or_insert(0.0) += value;
        }
        let mut new_indices = Vec::new();
        let mut new_values = Vec::new();
        for (indices, value) in values_map {
            if value.abs() > 1e-15 {
                new_indices.push(indices);
                new_values.push(value);
            }
        }
        let mut coo = SparseCOO {
            shape: new_shape,
            indices: new_indices,
            values: new_values,
            sorted: false,
        };
        coo.sort();
        Ok(SparseTensor::from_coo(coo))
    }

    /// Maximum over the given axes.
    ///
    /// Groups that also contain unstored (implicitly zero) elements take
    /// those zeros into account, so an all-negative stored group with at
    /// least one implicit zero reduces to 0.0 (then dropped as a zero);
    /// previously the negative stored maximum was wrongly kept. Assumes
    /// duplicate coordinates have been merged (see
    /// `SparseCOO::sum_duplicates`) — duplicates skew the per-group entry
    /// count. An empty `axes` slice returns a clone.
    ///
    /// # Errors
    /// `IndexOutOfBounds` when an axis is >= `self.ndim()`.
    pub fn max(&self, axes: &[usize]) -> SparseResult<SparseTensor> {
        if axes.is_empty() {
            return Ok(self.clone());
        }
        for &axis in axes {
            if axis >= self.ndim() {
                return Err(SparseError::IndexOutOfBounds {
                    index: vec![axis],
                    shape: self.coo.shape.clone(),
                });
            }
        }
        let kept_dims: Vec<usize> = (0..self.ndim()).filter(|d| !axes.contains(d)).collect();
        let new_shape: Vec<usize> = kept_dims.iter().map(|&d| self.coo.shape[d]).collect();
        // Number of elements collapsed into each output cell (iterating the
        // dimensions rather than `axes` tolerates duplicate axis entries).
        let reduced_size: usize = (0..self.ndim())
            .filter(|d| axes.contains(d))
            .map(|d| self.coo.shape[d])
            .product();
        if new_shape.is_empty() {
            // Full reduction to a 0-D scalar.
            let stored_max = self
                .coo
                .values
                .iter()
                .copied()
                .fold(f64::NEG_INFINITY, f64::max);
            // Any unstored element is an implicit zero competing in the max.
            let max_val = if self.nnz() < self.numel() {
                stored_max.max(0.0)
            } else {
                stored_max
            };
            let mut coo = SparseCOO::new(vec![]);
            if max_val.abs() > 1e-15 && max_val.is_finite() {
                coo.indices.push(vec![]);
                coo.values.push(max_val);
            }
            return Ok(SparseTensor::from_coo(coo));
        }
        // Track (running max, stored-entry count) per output coordinate.
        let mut values_map: HashMap<Vec<usize>, (f64, usize)> = HashMap::new();
        for (indices, &value) in self.coo.indices.iter().zip(self.coo.values.iter()) {
            let key: Vec<usize> = kept_dims.iter().map(|&d| indices[d]).collect();
            values_map
                .entry(key)
                .and_modify(|(v, count)| {
                    *v = v.max(value);
                    *count += 1;
                })
                .or_insert((value, 1));
        }
        let mut new_indices = Vec::new();
        let mut new_values = Vec::new();
        for (indices, (stored_max, count)) in values_map {
            // A group with unstored elements also holds implicit zeros, so
            // its maximum can never be negative.
            let value = if count < reduced_size {
                stored_max.max(0.0)
            } else {
                stored_max
            };
            if value.abs() > 1e-15 {
                new_indices.push(indices);
                new_values.push(value);
            }
        }
        let mut coo = SparseCOO {
            shape: new_shape,
            indices: new_indices,
            values: new_values,
            sorted: false,
        };
        coo.sort();
        Ok(SparseTensor::from_coo(coo))
    }
}
/// Tunables controlling when dense data should be stored sparsely.
#[derive(Debug, Clone)]
pub struct SparseConfig {
    /// Values with magnitude at or below this count as zero.
    pub zero_threshold: f64,
    /// Minimum zero-fraction (exclusive) required before sparse storage is
    /// considered worthwhile.
    pub sparsity_threshold: f64,
    /// Storage format to prefer when converting.
    pub preferred_format: SparseFormat,
}
impl Default for SparseConfig {
    /// Conservative defaults: treat |v| <= 1e-15 as zero, require more than
    /// half of the elements to be zero, and prefer COO storage.
    fn default() -> Self {
        Self {
            zero_threshold: 1e-15,
            sparsity_threshold: 0.5,
            preferred_format: SparseFormat::COO,
        }
    }
}
impl SparseConfig {
    /// Reports whether `dense` is sparse enough — per `zero_threshold` and
    /// `sparsity_threshold` — to benefit from sparse storage. An empty
    /// array never qualifies.
    pub fn should_use_sparse(&self, dense: &ArrayD<f64>) -> bool {
        let total = dense.len();
        if total == 0 {
            return false;
        }
        // Count entries whose magnitude exceeds the zero threshold, then
        // compare the resulting zero-fraction against the cutoff.
        let nnz = dense
            .iter()
            .filter(|v| v.abs() > self.zero_threshold)
            .count();
        let sparsity = 1.0 - (nnz as f64 / total as f64);
        sparsity > self.sparsity_threshold
    }
}
/// Snapshot of size and sparsity metrics for a `SparseTensor`.
#[derive(Debug, Clone)]
pub struct SparseStats {
    /// Number of explicitly stored entries.
    pub nnz: usize,
    /// Total number of elements, stored or implicit.
    pub numel: usize,
    /// Fraction of elements that are zero.
    pub sparsity: f64,
    /// Estimated bytes used by the sparse representation.
    pub sparse_bytes: usize,
    /// Bytes a dense f64 array of the same shape would use.
    pub dense_bytes: usize,
    /// 1 - sparse_bytes / dense_bytes (0.0 for empty tensors).
    pub savings: f64,
}
impl SparseStats {
    /// Gathers size and sparsity statistics for `tensor`.
    ///
    /// Byte estimates assume 8 bytes per f64 value and, for the sparse
    /// form, one usize (8 bytes) per dimension per stored entry.
    pub fn from_tensor(tensor: &SparseTensor) -> Self {
        let nnz = tensor.nnz();
        let numel = tensor.numel();
        let dense_bytes = numel * 8;
        let sparse_bytes = nnz * (8 + tensor.ndim() * 8);
        // An empty tensor has nothing to save.
        let savings = if dense_bytes == 0 {
            0.0
        } else {
            1.0 - (sparse_bytes as f64 / dense_bytes as f64)
        };
        SparseStats {
            nnz,
            numel,
            sparsity: tensor.sparsity(),
            sparse_bytes,
            dense_bytes,
            savings,
        }
    }
}
impl std::fmt::Display for SparseStats {
    /// Renders a compact one-line summary, with sparsity and savings shown
    /// as percentages to one decimal place.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let sparsity_pct = self.sparsity * 100.0;
        let savings_pct = self.savings * 100.0;
        write!(
            f,
            "SparseStats {{ nnz: {}, numel: {}, sparsity: {sparsity_pct:.1}%, savings: {savings_pct:.1}% }}",
            self.nnz, self.numel
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Triplet construction stores exactly the provided nonzeros.
    #[test]
    fn test_coo_from_triplets() {
        let coo = SparseCOO::from_triplets(vec![3, 3], vec![(0, 0, 1.0), (1, 1, 2.0), (2, 2, 3.0)]);
        assert_eq!(coo.shape, vec![3, 3]);
        assert_eq!(coo.nnz(), 3);
        assert_eq!(coo.get(&[0, 0]), 1.0);
        assert_eq!(coo.get(&[1, 1]), 2.0);
        assert_eq!(coo.get(&[2, 2]), 3.0);
        // Unstored coordinates read back as zero.
        assert_eq!(coo.get(&[0, 1]), 0.0);
    }

    // Dense -> COO keeps only the three nonzeros of this half-zero array.
    #[test]
    fn test_coo_from_dense() -> Result<(), Box<dyn std::error::Error>> {
        let dense = Array::from_shape_vec(IxDyn(&[2, 3]), vec![1.0, 0.0, 2.0, 0.0, 3.0, 0.0])
            .map_err(|e| format!("{e}"))?;
        let coo = SparseCOO::from_dense(&dense);
        assert_eq!(coo.shape, vec![2, 3]);
        assert_eq!(coo.nnz(), 3);
        assert!((coo.sparsity() - 0.5).abs() < 0.01);
        Ok(())
    }

    // COO -> dense round-trips values and fills the rest with zeros.
    #[test]
    fn test_coo_to_dense() {
        let coo =
            SparseCOO::from_triplets(vec![2, 2], vec![(0, 0, 1.0), (0, 1, 2.0), (1, 0, 3.0)]);
        let dense = coo.to_dense();
        assert_eq!(dense[[0, 0]], 1.0);
        assert_eq!(dense[[0, 1]], 2.0);
        assert_eq!(dense[[1, 0]], 3.0);
        assert_eq!(dense[[1, 1]], 0.0);
    }

    // CSR lookup matches the COO source, including an absent coordinate.
    #[test]
    fn test_csr_from_coo() -> Result<(), Box<dyn std::error::Error>> {
        let coo = SparseCOO::from_triplets(
            vec![3, 4],
            vec![(0, 0, 1.0), (0, 2, 2.0), (1, 1, 3.0), (2, 3, 4.0)],
        );
        let csr = SparseCSR::from_coo(&coo)?;
        assert_eq!(csr.shape, (3, 4));
        assert_eq!(csr.nnz(), 4);
        assert_eq!(csr.get(0, 0), 1.0);
        assert_eq!(csr.get(0, 2), 2.0);
        assert_eq!(csr.get(1, 1), 3.0);
        assert_eq!(csr.get(2, 3), 4.0);
        assert_eq!(csr.get(0, 1), 0.0);
        Ok(())
    }

    // [[1, 2], [3, 4]] * [1, 2] == [5, 11].
    #[test]
    fn test_csr_matvec() -> Result<(), Box<dyn std::error::Error>> {
        let coo = SparseCOO::from_triplets(
            vec![2, 2],
            vec![(0, 0, 1.0), (0, 1, 2.0), (1, 0, 3.0), (1, 1, 4.0)],
        );
        let csr = SparseCSR::from_coo(&coo)?;
        let x = vec![1.0, 2.0];
        let y = csr.matvec(&x)?;
        assert!((y[0] - 5.0).abs() < 1e-10);
        assert!((y[1] - 11.0).abs() < 1e-10);
        Ok(())
    }

    // Overlapping coordinates sum; disjoint ones pass through unchanged.
    #[test]
    fn test_sparse_tensor_add() -> Result<(), Box<dyn std::error::Error>> {
        let a = SparseTensor::from_coo(SparseCOO::from_triplets(
            vec![2, 2],
            vec![(0, 0, 1.0), (0, 1, 2.0)],
        ));
        let b = SparseTensor::from_coo(SparseCOO::from_triplets(
            vec![2, 2],
            vec![(0, 0, 3.0), (1, 1, 4.0)],
        ));
        let c = a.add(&b)?;
        assert_eq!(c.get(&[0, 0]), 4.0);
        assert_eq!(c.get(&[0, 1]), 2.0);
        assert_eq!(c.get(&[1, 1]), 4.0);
        Ok(())
    }

    // The product is nonzero only where both operands store a value.
    #[test]
    fn test_sparse_tensor_hadamard() -> Result<(), Box<dyn std::error::Error>> {
        let a = SparseTensor::from_coo(SparseCOO::from_triplets(
            vec![2, 2],
            vec![(0, 0, 2.0), (0, 1, 3.0), (1, 0, 4.0)],
        ));
        let b = SparseTensor::from_coo(SparseCOO::from_triplets(
            vec![2, 2],
            vec![(0, 0, 5.0), (1, 0, 2.0), (1, 1, 3.0)],
        ));
        let c = a.hadamard(&b)?;
        assert_eq!(c.get(&[0, 0]), 10.0);
        assert_eq!(c.get(&[0, 1]), 0.0);
        assert_eq!(c.get(&[1, 0]), 8.0);
        assert_eq!(c.get(&[1, 1]), 0.0);
        Ok(())
    }

    // Element-wise maximum across operands (all stored values positive).
    #[test]
    fn test_sparse_tensor_maximum() -> Result<(), Box<dyn std::error::Error>> {
        let a = SparseTensor::from_coo(SparseCOO::from_triplets(
            vec![2, 2],
            vec![(0, 0, 1.0), (0, 1, 5.0)],
        ));
        let b = SparseTensor::from_coo(SparseCOO::from_triplets(
            vec![2, 2],
            vec![(0, 0, 3.0), (1, 1, 4.0)],
        ));
        let c = a.maximum(&b)?;
        assert_eq!(c.get(&[0, 0]), 3.0);
        assert_eq!(c.get(&[0, 1]), 5.0);
        assert_eq!(c.get(&[1, 1]), 4.0);
        Ok(())
    }

    // Summing over axis 1 collapses a 2x3 tensor into per-row totals.
    #[test]
    fn test_sparse_tensor_sum() -> Result<(), Box<dyn std::error::Error>> {
        let sparse = SparseTensor::from_coo(SparseCOO::from_triplets(
            vec![2, 3],
            vec![(0, 0, 1.0), (0, 1, 2.0), (1, 0, 3.0), (1, 2, 4.0)],
        ));
        let summed = sparse.sum(&[1])?;
        assert_eq!(summed.shape(), &[2]);
        assert_eq!(summed.get(&[0]), 3.0);
        assert_eq!(summed.get(&[1]), 7.0);
        Ok(())
    }

    // Scalar multiplication scales every stored value.
    #[test]
    fn test_sparse_tensor_scale() {
        let sparse = SparseTensor::from_coo(SparseCOO::from_triplets(
            vec![2, 2],
            vec![(0, 0, 2.0), (1, 1, 3.0)],
        ));
        let scaled = sparse.scale(2.0);
        assert_eq!(scaled.get(&[0, 0]), 4.0);
        assert_eq!(scaled.get(&[1, 1]), 6.0);
    }

    // `map` applies the closure to each stored value (sqrt maps 0 -> 0, so
    // sparsity is preserved).
    #[test]
    fn test_sparse_tensor_map() {
        let sparse = SparseTensor::from_coo(SparseCOO::from_triplets(
            vec![2, 2],
            vec![(0, 0, 4.0), (1, 1, 9.0)],
        ));
        let mapped = sparse.map(|x| x.sqrt());
        assert!((mapped.get(&[0, 0]) - 2.0).abs() < 1e-10);
        assert!((mapped.get(&[1, 1]) - 3.0).abs() < 1e-10);
    }

    // Identity matrix stores exactly n ones on the diagonal.
    #[test]
    fn test_sparse_tensor_eye() {
        let eye = SparseTensor::eye(3);
        assert_eq!(eye.shape(), &[3, 3]);
        assert_eq!(eye.nnz(), 3);
        assert_eq!(eye.get(&[0, 0]), 1.0);
        assert_eq!(eye.get(&[1, 1]), 1.0);
        assert_eq!(eye.get(&[2, 2]), 1.0);
        assert_eq!(eye.get(&[0, 1]), 0.0);
    }

    // Stats for a 2-of-100 tensor: 98% sparse with large memory savings.
    #[test]
    fn test_sparse_stats() {
        let sparse = SparseTensor::from_coo(SparseCOO::from_triplets(
            vec![10, 10],
            vec![(0, 0, 1.0), (5, 5, 2.0)],
        ));
        let stats = SparseStats::from_tensor(&sparse);
        assert_eq!(stats.nnz, 2);
        assert_eq!(stats.numel, 100);
        assert!((stats.sparsity - 0.98).abs() < 0.01);
        assert!(stats.savings > 0.5);
    }

    // Default config: 90% zeros qualifies for sparse, 50% does not
    // (threshold comparison is strict).
    #[test]
    fn test_sparse_config() -> Result<(), Box<dyn std::error::Error>> {
        let config = SparseConfig::default();
        let sparse_data = Array::from_shape_vec(
            IxDyn(&[10]),
            vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        )
        .map_err(|e| format!("{e}"))?;
        assert!(config.should_use_sparse(&sparse_data));
        let dense_data =
            Array::from_shape_vec(IxDyn(&[4]), vec![1.0, 0.0, 2.0, 0.0])
                .map_err(|e| format!("{e}"))?;
        assert!(!config.should_use_sparse(&dense_data));
        Ok(())
    }
}