use crate::sparse::core::SparseTensor;
use torsh_core::{Result as TorshResult, TorshError};
use torsh_tensor::Tensor;
/// Sums a sparse COO tensor, treating every unstored position as zero.
///
/// * `dim == None` reduces over all elements and returns a tensor of shape `[1]`.
/// * `dim == Some(d)` reduces along dimension `d`; the result has `sparse.shape` with
///   entry `d` removed, or shape `[1]` when the input is one-dimensional.
///
/// Entries that share a coordinate are simply added together, so the result equals
/// the sum of the coalesced tensor.
///
/// # Errors
///
/// With `Some(d)`, returns an invalid-argument error when `d >= sparse.ndim`, when the
/// index/value storage holds fewer entries than `nnz` requires, or when a stored
/// coordinate lies outside `sparse.shape`. Errors from reading tensor storage are
/// propagated in both modes.
pub fn sparse_sum(sparse: &SparseTensor, dim: Option<usize>) -> TorshResult<Tensor> {
    match dim {
        None => {
            let values = sparse.values.to_vec()?;
            let total_sum: f32 = values.iter().sum();
            Tensor::from_data(vec![total_sum], vec![1], sparse.values.device())
        }
        Some(d) => {
            if d >= sparse.ndim {
                return Err(TorshError::invalid_argument_with_context(
                    "Dimension out of range",
                    "sparse_sum",
                ));
            }
            let values = sparse.values.to_vec()?;
            let indices = sparse.indices.to_vec()?;
            // Coordinates are stored dimension-major: entry `i`'s coordinate along
            // dimension `j` is `indices[j * nnz + i]`. Check both buffers are large
            // enough up front instead of panicking partway through the loop.
            if values.len() < sparse.nnz || indices.len() < sparse.ndim * sparse.nnz {
                return Err(TorshError::invalid_argument_with_context(
                    "Sparse storage is smaller than nnz",
                    "sparse_sum",
                ));
            }
            let mut result_shape = sparse.shape.clone();
            result_shape.remove(d);
            if result_shape.is_empty() {
                result_shape.push(1);
            }
            let result_size: usize = result_shape.iter().product();
            let mut result_data = vec![0.0f32; result_size];
            for (i, &value) in values.iter().take(sparse.nnz).enumerate() {
                // Row-major offset into the output, skipping the reduced dimension `d`;
                // computed directly rather than via a per-entry coordinate Vec.
                let mut flat_idx = 0;
                let mut stride = 1;
                for j in (0..sparse.ndim).rev() {
                    let coord = indices[j * sparse.nnz + i] as usize;
                    // An out-of-range coordinate would otherwise panic on the output
                    // index or, along the innermost kept dimension, silently land in a
                    // neighbouring output slot.
                    if coord >= sparse.shape[j] {
                        return Err(TorshError::invalid_argument_with_context(
                            "Sparse index out of range",
                            "sparse_sum",
                        ));
                    }
                    if j != d {
                        flat_idx += coord * stride;
                        stride *= sparse.shape[j];
                    }
                }
                result_data[flat_idx] += value;
            }
            Tensor::from_data(result_data, result_shape, sparse.values.device())
        }
    }
}
/// Averages a sparse COO tensor over all of its logical elements, counting every
/// unstored position as an implicit zero.
///
/// * `dim == None` divides the sum of the stored values by the product of
///   `sparse.shape`, yielding a tensor of shape `[1]`. An empty tensor yields NaN
///   (0 / 0).
/// * `dim == Some(d)` divides the [`sparse_sum`] along `d` by `sparse.shape[d]`.
///
/// # Errors
///
/// Propagates storage-read errors and, for `Some(d)`, every error raised by
/// [`sparse_sum`], including its rejection of an out-of-range `d`.
pub fn sparse_mean(sparse: &SparseTensor, dim: Option<usize>) -> TorshResult<Tensor> {
    if let Some(d) = dim {
        // `sparse_sum` rejects `d >= ndim` before we index the shape with `d`.
        let summed = sparse_sum(sparse, Some(d))?;
        let extent = sparse.shape[d] as f32;
        return summed.div_scalar(extent);
    }
    let stored = sparse.values.to_vec()?;
    let numerator: f32 = stored.iter().sum();
    let element_count = sparse.shape.iter().product::<usize>();
    let average = numerator / element_count as f32;
    Tensor::from_data(vec![average], vec![1], sparse.values.device())
}
/// Returns the maximum of a sparse COO tensor, where unstored positions are implicit zeros.
///
/// An implicit zero takes part in a reduction only when the reduced region actually
/// contains an unstored position. A region whose every position is stored reduces over
/// the stored values alone, so a fully populated all-negative row yields its negative
/// maximum rather than `0`. Regions containing no elements at all (such as an empty
/// tensor or an empty dimension `d`) yield `0`.
///
/// * `dim == None` reduces over all elements and returns a tensor of shape `[1]`.
/// * `dim == Some(d)` reduces along dimension `d`; the result has `sparse.shape` with
///   entry `d` removed, or shape `[1]` when the input is one-dimensional.
///
/// NOTE(review): whether a region has unstored positions is decided by comparing the
/// number of stored entries with the region's size, which assumes the coordinates are
/// coalesced (no duplicates). Confirm callers coalesce before reducing.
///
/// # Errors
///
/// With `Some(d)`, returns an invalid-argument error when `d >= sparse.ndim`, when the
/// index/value storage holds fewer entries than `nnz` requires, or when a stored
/// coordinate lies outside `sparse.shape`. Errors from reading tensor storage are
/// propagated in both modes.
pub fn sparse_max(sparse: &SparseTensor, dim: Option<usize>) -> TorshResult<Tensor> {
    match dim {
        None => {
            let values = sparse.values.to_vec()?;
            let total_elements: usize = sparse.shape.iter().product();
            // Seed the fold with the implicit zero only when some position is unstored
            // (or nothing is stored at all); a fully dense tensor must not be clamped.
            let seed = if sparse.nnz < total_elements || values.is_empty() {
                0.0
            } else {
                f32::NEG_INFINITY
            };
            let max_val = values.iter().fold(seed, |a, &b| a.max(b));
            Tensor::from_data(vec![max_val], vec![1], sparse.values.device())
        }
        Some(d) => {
            if d >= sparse.ndim {
                return Err(TorshError::invalid_argument_with_context(
                    "Dimension out of range",
                    "sparse_max",
                ));
            }
            let values = sparse.values.to_vec()?;
            let indices = sparse.indices.to_vec()?;
            // Coordinates are stored dimension-major: entry `i`'s coordinate along
            // dimension `j` is `indices[j * nnz + i]`.
            if values.len() < sparse.nnz || indices.len() < sparse.ndim * sparse.nnz {
                return Err(TorshError::invalid_argument_with_context(
                    "Sparse storage is smaller than nnz",
                    "sparse_max",
                ));
            }
            let mut result_shape = sparse.shape.clone();
            result_shape.remove(d);
            if result_shape.is_empty() {
                result_shape.push(1);
            }
            let result_size: usize = result_shape.iter().product();
            let mut result_data = vec![f32::NEG_INFINITY; result_size];
            // Number of stored entries that landed in each output slot.
            let mut stored_counts = vec![0usize; result_size];
            for (i, &value) in values.iter().take(sparse.nnz).enumerate() {
                // Row-major offset into the output, skipping the reduced dimension `d`.
                let mut flat_idx = 0;
                let mut stride = 1;
                for j in (0..sparse.ndim).rev() {
                    let coord = indices[j * sparse.nnz + i] as usize;
                    if coord >= sparse.shape[j] {
                        return Err(TorshError::invalid_argument_with_context(
                            "Sparse index out of range",
                            "sparse_max",
                        ));
                    }
                    if j != d {
                        flat_idx += coord * stride;
                        stride *= sparse.shape[j];
                    }
                }
                result_data[flat_idx] = result_data[flat_idx].max(value);
                stored_counts[flat_idx] += 1;
            }
            // A slot excludes the implicit zero only when all `shape[d]` positions along
            // `d` are stored; slots with no stored entries fall back to 0.
            let dim_size = sparse.shape[d];
            for (val, &count) in result_data.iter_mut().zip(&stored_counts) {
                if count == 0 || count < dim_size {
                    *val = val.max(0.0);
                }
            }
            Tensor::from_data(result_data, result_shape, sparse.values.device())
        }
    }
}
/// Returns the minimum of a sparse COO tensor, where unstored positions are implicit zeros.
///
/// An implicit zero takes part in a reduction only when the reduced region actually
/// contains an unstored position. A region whose every position is stored reduces over
/// the stored values alone, so a fully populated all-positive row yields its positive
/// minimum rather than `0`. Regions containing no elements at all (such as an empty
/// tensor or an empty dimension `d`) yield `0`.
///
/// * `dim == None` reduces over all elements and returns a tensor of shape `[1]`.
/// * `dim == Some(d)` reduces along dimension `d`; the result has `sparse.shape` with
///   entry `d` removed, or shape `[1]` when the input is one-dimensional.
///
/// NOTE(review): whether a region has unstored positions is decided by comparing the
/// number of stored entries with the region's size, which assumes the coordinates are
/// coalesced (no duplicates). Confirm callers coalesce before reducing.
///
/// # Errors
///
/// With `Some(d)`, returns an invalid-argument error when `d >= sparse.ndim`, when the
/// index/value storage holds fewer entries than `nnz` requires, or when a stored
/// coordinate lies outside `sparse.shape`. Errors from reading tensor storage are
/// propagated in both modes.
pub fn sparse_min(sparse: &SparseTensor, dim: Option<usize>) -> TorshResult<Tensor> {
    match dim {
        None => {
            let values = sparse.values.to_vec()?;
            let total_elements: usize = sparse.shape.iter().product();
            // Seed the fold with the implicit zero only when some position is unstored
            // (or nothing is stored at all); a fully dense tensor must not be clamped.
            let seed = if sparse.nnz < total_elements || values.is_empty() {
                0.0
            } else {
                f32::INFINITY
            };
            let min_val = values.iter().fold(seed, |a, &b| a.min(b));
            Tensor::from_data(vec![min_val], vec![1], sparse.values.device())
        }
        Some(d) => {
            if d >= sparse.ndim {
                return Err(TorshError::invalid_argument_with_context(
                    "Dimension out of range",
                    "sparse_min",
                ));
            }
            let values = sparse.values.to_vec()?;
            let indices = sparse.indices.to_vec()?;
            // Coordinates are stored dimension-major: entry `i`'s coordinate along
            // dimension `j` is `indices[j * nnz + i]`.
            if values.len() < sparse.nnz || indices.len() < sparse.ndim * sparse.nnz {
                return Err(TorshError::invalid_argument_with_context(
                    "Sparse storage is smaller than nnz",
                    "sparse_min",
                ));
            }
            let mut result_shape = sparse.shape.clone();
            result_shape.remove(d);
            if result_shape.is_empty() {
                result_shape.push(1);
            }
            let result_size: usize = result_shape.iter().product();
            let mut result_data = vec![f32::INFINITY; result_size];
            // Number of stored entries that landed in each output slot.
            let mut stored_counts = vec![0usize; result_size];
            for (i, &value) in values.iter().take(sparse.nnz).enumerate() {
                // Row-major offset into the output, skipping the reduced dimension `d`.
                let mut flat_idx = 0;
                let mut stride = 1;
                for j in (0..sparse.ndim).rev() {
                    let coord = indices[j * sparse.nnz + i] as usize;
                    if coord >= sparse.shape[j] {
                        return Err(TorshError::invalid_argument_with_context(
                            "Sparse index out of range",
                            "sparse_min",
                        ));
                    }
                    if j != d {
                        flat_idx += coord * stride;
                        stride *= sparse.shape[j];
                    }
                }
                result_data[flat_idx] = result_data[flat_idx].min(value);
                stored_counts[flat_idx] += 1;
            }
            // A slot excludes the implicit zero only when all `shape[d]` positions along
            // `d` are stored; slots with no stored entries fall back to 0.
            let dim_size = sparse.shape[d];
            for (val, &count) in result_data.iter_mut().zip(&stored_counts) {
                if count == 0 || count < dim_size {
                    *val = val.min(0.0);
                }
            }
            Tensor::from_data(result_data, result_shape, sparse.values.device())
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::core::sparse_coo_tensor;

    /// Absolute tolerance shared by every floating-point comparison in this module.
    const TOLERANCE: f32 = 1e-6;

    /// Asserts that `actual` lies strictly within [`TOLERANCE`] of `expected`.
    #[track_caller]
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < TOLERANCE,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    /// Builds a 3x3 COO tensor storing `diagonal` at (0,0), (1,1), (2,2);
    /// every off-diagonal position is an implicit zero.
    fn diagonal_3x3(diagonal: Vec<f32>) -> TorshResult<SparseTensor> {
        let values = Tensor::from_data(diagonal, vec![3], torsh_core::DeviceType::Cpu)?;
        // Row coordinates first, then column coordinates.
        let indices = Tensor::from_data(
            vec![0.0, 1.0, 2.0, 0.0, 1.0, 2.0],
            vec![2, 3],
            torsh_core::DeviceType::Cpu,
        )?;
        let shape = vec![3, 3];
        let sparse = sparse_coo_tensor(&indices, &values, &shape)?;
        Ok(sparse)
    }

    #[test]
    fn test_sparse_sum() -> TorshResult<()> {
        let sparse = diagonal_3x3(vec![1.0, 2.0, 3.0])?;

        let total = sparse_sum(&sparse, None)?.to_vec()?;
        assert_close(total[0], 6.0);

        let along_rows = sparse_sum(&sparse, Some(0))?.to_vec()?;
        assert_eq!(along_rows.len(), 3);
        assert_close(along_rows[0], 1.0);
        assert_close(along_rows[1], 2.0);
        assert_close(along_rows[2], 3.0);
        Ok(())
    }

    #[test]
    fn test_sparse_mean() -> TorshResult<()> {
        let sparse = diagonal_3x3(vec![1.0, 2.0, 3.0])?;

        let total = sparse_mean(&sparse, None)?.to_vec()?;
        assert_close(total[0], 6.0 / 9.0);

        let along_rows = sparse_mean(&sparse, Some(0))?.to_vec()?;
        assert_eq!(along_rows.len(), 3);
        assert_close(along_rows[0], 1.0 / 3.0);
        assert_close(along_rows[1], 2.0 / 3.0);
        assert_close(along_rows[2], 3.0 / 3.0);
        Ok(())
    }

    #[test]
    fn test_sparse_max() -> TorshResult<()> {
        let sparse = diagonal_3x3(vec![1.0, -2.0, 3.0])?;

        let total = sparse_max(&sparse, None)?.to_vec()?;
        assert_close(total[0], 3.0);

        // Column 1 only stores -2.0, so its implicit zeros win.
        let along_rows = sparse_max(&sparse, Some(0))?.to_vec()?;
        assert_eq!(along_rows.len(), 3);
        assert_close(along_rows[0], 1.0);
        assert_close(along_rows[1], 0.0);
        assert_close(along_rows[2], 3.0);
        Ok(())
    }

    #[test]
    fn test_sparse_min() -> TorshResult<()> {
        let sparse = diagonal_3x3(vec![1.0, -2.0, 3.0])?;

        let total = sparse_min(&sparse, None)?.to_vec()?;
        assert_close(total[0], -2.0);

        // Columns 0 and 2 store positive values, so their implicit zeros win.
        let along_rows = sparse_min(&sparse, Some(0))?.to_vec()?;
        assert_eq!(along_rows.len(), 3);
        assert_close(along_rows[0], 0.0);
        assert_close(along_rows[1], -2.0);
        assert_close(along_rows[2], 0.0);
        Ok(())
    }
}