use crate::{CooTensor, SparseTensor, TorshResult};
use std::collections::HashMap;
use torsh_core::TorshError;
use torsh_tensor::{creation::zeros, Tensor};
/// Shared validation and conversion helpers for the sparse operations below.
mod utils {
    use super::*;

    /// Error unless the two sparse operands have identical shapes.
    pub fn validate_same_shape(a: &dyn SparseTensor, b: &dyn SparseTensor) -> TorshResult<()> {
        if a.shape() != b.shape() {
            return Err(TorshError::InvalidArgument(format!(
                "Shape mismatch: {:?} vs {:?}",
                a.shape(),
                b.shape()
            )));
        }
        Ok(())
    }

    /// Error unless `a` (sparse) can be matrix-multiplied by `b` (dense, 2-D).
    pub fn validate_matmul_dims(a: &dyn SparseTensor, b: &Tensor) -> TorshResult<()> {
        if b.shape().ndim() != 2 {
            return Err(TorshError::InvalidArgument(
                "Right operand must be a 2D tensor".to_string(),
            ));
        }
        if a.shape().dims()[1] != b.shape().dims()[0] {
            return Err(TorshError::InvalidArgument(format!(
                "Dimension mismatch: [{} x {}] @ [{} x {}]",
                a.shape().dims()[0],
                a.shape().dims()[1],
                b.shape().dims()[0],
                b.shape().dims()[1]
            )));
        }
        Ok(())
    }

    /// Error unless the sparse matrix is square.
    pub fn validate_square(tensor: &dyn SparseTensor) -> TorshResult<()> {
        let shape = tensor.shape();
        if shape.dims()[0] != shape.dims()[1] {
            return Err(TorshError::InvalidArgument(
                "Matrix must be square".to_string(),
            ));
        }
        Ok(())
    }

    /// Only axes 0 (rows) and 1 (columns) exist for these 2-D sparse tensors.
    pub fn validate_axis(axis: usize) -> TorshResult<()> {
        if axis > 1 {
            return Err(TorshError::InvalidArgument(
                "Axis must be 0 (rows) or 1 (columns)".to_string(),
            ));
        }
        Ok(())
    }

    /// Convert any sparse representation to COO, propagating conversion errors.
    pub fn to_coo_safe(tensor: &dyn SparseTensor) -> TorshResult<CooTensor> {
        tensor.to_coo()
    }

    /// Build a (row, col) -> value map for O(1) positional lookups.
    /// If the COO carries duplicate coordinates, the last one wins.
    pub fn create_position_map(coo: &CooTensor) -> HashMap<(usize, usize), f32> {
        coo.triplets()
            .into_iter()
            .map(|(row, col, val)| ((row, col), val))
            .collect()
    }

    /// Split triplets into parallel index/value vectors, dropping any entry
    /// whose magnitude is at or below `threshold`.
    pub fn extract_filtered_triplets(
        triplets: Vec<(usize, usize, f32)>,
        threshold: f32,
    ) -> (Vec<usize>, Vec<usize>, Vec<f32>) {
        let mut rows = Vec::new();
        let mut cols = Vec::new();
        let mut vals = Vec::new();
        for (r, c, v) in triplets {
            if v.abs() > threshold {
                rows.push(r);
                cols.push(c);
                vals.push(v);
            }
        }
        (rows, cols, vals)
    }

    /// Apply `op` at every position stored in BOTH operands (intersection
    /// semantics). Results with magnitude <= f32::EPSILON are dropped.
    pub fn element_wise_operation<F>(
        a: &dyn SparseTensor,
        b: &dyn SparseTensor,
        op: F,
    ) -> TorshResult<CooTensor>
    where
        F: Fn(f32, f32) -> f32,
    {
        validate_same_shape(a, b)?;
        let a_coo = to_coo_safe(a)?;
        let b_coo = to_coo_safe(b)?;
        let b_map = create_position_map(&b_coo);
        let mut row_indices = Vec::new();
        let mut col_indices = Vec::new();
        let mut values = Vec::new();
        for (row, col, a_val) in a_coo.triplets() {
            if let Some(&b_val) = b_map.get(&(row, col)) {
                let result = op(a_val, b_val);
                if result.abs() > f32::EPSILON {
                    row_indices.push(row);
                    col_indices.push(col);
                    values.push(result);
                }
            }
        }
        CooTensor::new(row_indices, col_indices, values, a.shape().clone())
    }

    /// Fold every stored triplet into an accumulator vector whose length is
    /// determined by `axis`: `None` -> 1 (global reduction), `Some(0)` -> one
    /// slot per column, `Some(1)` -> one slot per row.
    ///
    /// Invalid axes now propagate as `InvalidArgument` errors instead of the
    /// previous validate-then-panic dead path, and the three identical
    /// accumulation loops are merged into one.
    pub fn reduce_operation<F>(
        tensor: &dyn SparseTensor,
        axis: Option<usize>,
        op: F,
    ) -> TorshResult<Vec<f32>>
    where
        F: Fn(&mut Vec<f32>, usize, usize, f32),
    {
        if let Some(ax) = axis {
            validate_axis(ax)?;
        }
        let coo = to_coo_safe(tensor)?;
        let shape = tensor.shape();
        let len = match axis {
            None => 1,
            Some(0) => shape.dims()[1],
            Some(1) => shape.dims()[0],
            // validate_axis above already rejected anything > 1
            Some(_) => unreachable!("axis validated above"),
        };
        let mut result = vec![0.0; len];
        for (row, col, val) in coo.triplets() {
            op(&mut result, row, col, val);
        }
        Ok(result)
    }
}
/// Sparse × dense matrix multiply: returns the dense product `a @ b`.
///
/// `a` is read row-by-row in CSR form so only stored entries contribute to
/// each dot product.
pub fn spmm(a: &dyn SparseTensor, b: &Tensor) -> TorshResult<Tensor> {
    utils::validate_matmul_dims(a, b)?;
    let csr = a.to_csr()?;
    let rows = a.shape().dims()[0];
    let cols = b.shape().dims()[1];
    let out = zeros::<f32>(&[rows, cols])?;
    for r in 0..rows {
        let (col_idx, row_vals) = csr.get_row(r)?;
        for c in 0..cols {
            // Dot product of sparse row r with dense column c.
            let mut acc = 0.0;
            for (pos, &k) in col_idx.iter().enumerate() {
                acc += row_vals[pos] * b.get(&[k, c])?;
            }
            out.set(&[r, c], acc)?;
        }
    }
    Ok(out)
}
pub fn spadd(a: &dyn SparseTensor, b: &dyn SparseTensor, alpha: f32) -> TorshResult<CooTensor> {
utils::validate_same_shape(a, b)?;
let a_coo = utils::to_coo_safe(a)?;
let b_coo = utils::to_coo_safe(b)?;
let mut triplets = a_coo.triplets();
for (row, col, val) in b_coo.triplets() {
let mut found = false;
for t in triplets.iter_mut() {
if t.0 == row && t.1 == col {
t.2 += alpha * val;
found = true;
break;
}
}
if !found {
triplets.push((row, col, alpha * val));
}
}
let (row_indices, col_indices, values) =
utils::extract_filtered_triplets(triplets, f32::EPSILON);
CooTensor::new(row_indices, col_indices, values, a.shape().clone())
}
/// Element-wise (Hadamard) product of two same-shaped sparse tensors.
/// Only positions stored in both operands can be nonzero in the product.
pub fn sphadamard(a: &dyn SparseTensor, b: &dyn SparseTensor) -> TorshResult<CooTensor> {
    utils::element_wise_operation(a, b, |x, y| x * y)
}
/// Transpose a sparse tensor, returning the result in COO form.
pub fn transpose(tensor: &dyn SparseTensor) -> TorshResult<CooTensor> {
    Ok(utils::to_coo_safe(tensor)?.transpose())
}
/// Sum of all stored values (implicit zeros contribute nothing).
pub fn sum(tensor: &dyn SparseTensor) -> TorshResult<f32> {
    let totals = utils::reduce_operation(tensor, None, |acc, _, _, v| {
        acc[0] += v;
    })?;
    Ok(totals[0])
}
/// Sum along one axis: axis 0 collapses rows (one sum per column),
/// axis 1 collapses columns (one sum per row).
pub fn sum_axis(tensor: &dyn SparseTensor, axis: usize) -> TorshResult<Vec<f32>> {
    utils::validate_axis(axis)?;
    utils::reduce_operation(tensor, Some(axis), |acc, row, col, val| {
        // axis was validated above, so it is exactly 0 or 1 here.
        let slot = if axis == 0 { col } else { row };
        acc[slot] += val;
    })
}
/// Frobenius (L2) norm over the stored values.
pub fn norm(tensor: &dyn SparseTensor) -> TorshResult<f32> {
    let acc = utils::reduce_operation(tensor, None, |acc, _, _, v| {
        acc[0] += v * v;
    })?;
    Ok(acc[0].sqrt())
}
/// Multiply every stored value by `scalar`.
/// Entries that become exactly 0.0 are dropped (threshold 0.0).
pub fn scale(tensor: &dyn SparseTensor, scalar: f32) -> TorshResult<CooTensor> {
    let coo = utils::to_coo_safe(tensor)?;
    let mut scaled = Vec::new();
    for (r, c, v) in coo.triplets() {
        scaled.push((r, c, v * scalar));
    }
    let (rows, cols, vals) = utils::extract_filtered_triplets(scaled, 0.0);
    CooTensor::new(rows, cols, vals, tensor.shape().clone())
}
/// Extract the main diagonal of a square sparse matrix as a dense vector.
/// Positions with no stored diagonal entry are returned as 0.0.
pub fn diag(tensor: &dyn SparseTensor) -> TorshResult<Vec<f32>> {
    utils::validate_square(tensor)?;
    let coo = utils::to_coo_safe(tensor)?;
    let mut diagonal = vec![0.0; tensor.shape().dims()[0]];
    for (idx, _, v) in coo.triplets().into_iter().filter(|&(r, c, _)| r == c) {
        diagonal[idx] = v;
    }
    Ok(diagonal)
}
pub fn element_add(a: &dyn SparseTensor, b: &dyn SparseTensor) -> TorshResult<CooTensor> {
utils::element_wise_operation(a, b, |a_val, b_val| a_val + b_val)
}
pub fn element_sub(a: &dyn SparseTensor, b: &dyn SparseTensor) -> TorshResult<CooTensor> {
utils::element_wise_operation(a, b, |a_val, b_val| a_val - b_val)
}
/// Element-wise sparse division over positions stored in both operands.
/// A (near-)zero divisor yields 0.0 rather than inf/NaN.
pub fn element_div(a: &dyn SparseTensor, b: &dyn SparseTensor) -> TorshResult<CooTensor> {
    utils::element_wise_operation(a, b, |numer, denom| {
        match denom.abs() < f32::EPSILON {
            true => 0.0,
            false => numer / denom,
        }
    })
}
/// Mean along one axis, counting implicit zeros in the denominator:
/// axis 0 averages each column over all rows; axis 1 averages each row over
/// all columns.
pub fn mean_axis(tensor: &dyn SparseTensor, axis: usize) -> TorshResult<Vec<f32>> {
    let sums = sum_axis(tensor, axis)?;
    let shape = tensor.shape();
    // Number of elements collapsed into each output slot.
    let count = if axis == 0 {
        shape.dims()[0]
    } else {
        shape.dims()[1]
    } as f32;
    Ok(sums.into_iter().map(|s| s / count).collect())
}
/// Population variance along one axis (divided by the full dimension size).
///
/// Fixed: the previous implementation summed squared deviations only over the
/// STORED entries while still dividing by the full dimension — implicit zeros
/// each contribute `(0 - mean)^2` and were silently omitted, making this
/// inconsistent with the global `std`, which does account for them (see its
/// zero-adjustment term).
pub fn var_axis(tensor: &dyn SparseTensor, axis: usize) -> TorshResult<Vec<f32>> {
    utils::validate_axis(axis)?;
    let means = mean_axis(tensor, axis)?;
    let shape = tensor.shape();
    // out_len = number of output slots; n = elements collapsed per slot.
    let (out_len, n) = if axis == 0 {
        (shape.dims()[1], shape.dims()[0])
    } else {
        (shape.dims()[0], shape.dims()[1])
    };
    let mut sum_sq = vec![0.0f32; out_len];
    let mut counts = vec![0usize; out_len];
    let coo = utils::to_coo_safe(tensor)?;
    for (row, col, val) in coo.triplets() {
        let idx = if axis == 0 { col } else { row };
        let diff = val - means[idx];
        sum_sq[idx] += diff * diff;
        counts[idx] += 1;
    }
    // Each implicit zero contributes (0 - mean)^2 to its slot.
    for idx in 0..out_len {
        let zeros = n - counts[idx];
        sum_sq[idx] += zeros as f32 * means[idx] * means[idx];
    }
    Ok(sum_sq.into_iter().map(|v| v / n as f32).collect())
}
/// Largest absolute stored value. The accumulator starts at 0.0, which is
/// also the correct answer for a tensor with no stored entries.
pub fn max_abs(tensor: &dyn SparseTensor) -> TorshResult<f32> {
    let acc = utils::reduce_operation(tensor, None, |acc, _, _, v| {
        acc[0] = acc[0].max(v.abs());
    })?;
    Ok(acc[0])
}
/// Minimum over ALL elements, including implicit zeros: if the matrix is not
/// fully dense and every stored value is positive, the true minimum is 0.0.
/// An empty tensor yields 0.0.
pub fn min(tensor: &dyn SparseTensor) -> TorshResult<f32> {
    let coo = utils::to_coo_safe(tensor)?;
    let triplets = coo.triplets();
    let first = match triplets.first() {
        Some(&(_, _, v)) => v,
        None => return Ok(0.0),
    };
    let stored_min = triplets.iter().fold(first, |m, &(_, _, v)| m.min(v));
    let total = tensor.shape().dims().iter().product::<usize>();
    // Any unstored element is an implicit zero competing for the minimum.
    if triplets.len() < total && stored_min > 0.0 {
        Ok(0.0)
    } else {
        Ok(stored_min)
    }
}
/// Maximum over ALL elements, including implicit zeros: if the matrix is not
/// fully dense and every stored value is negative, the true maximum is 0.0.
/// An empty tensor yields 0.0.
pub fn max(tensor: &dyn SparseTensor) -> TorshResult<f32> {
    let coo = utils::to_coo_safe(tensor)?;
    let triplets = coo.triplets();
    let first = match triplets.first() {
        Some(&(_, _, v)) => v,
        None => return Ok(0.0),
    };
    let stored_max = triplets.iter().fold(first, |m, &(_, _, v)| m.max(v));
    let total = tensor.shape().dims().iter().product::<usize>();
    // Any unstored element is an implicit zero competing for the maximum.
    if triplets.len() < total && stored_max < 0.0 {
        Ok(0.0)
    } else {
        Ok(stored_max)
    }
}
/// Mean over ALL elements: the stored-value sum divided by the total element
/// count (implicit zeros included in the denominator).
pub fn mean(tensor: &dyn SparseTensor) -> TorshResult<f32> {
    let count = tensor.shape().dims().iter().product::<usize>() as f32;
    Ok(sum(tensor)? / count)
}
pub fn std(tensor: &dyn SparseTensor, ddof: usize) -> TorshResult<f32> {
let mean_val = mean(tensor)?;
let coo = utils::to_coo_safe(tensor)?;
let triplets = coo.triplets();
let total_elements = tensor.shape().dims().iter().product::<usize>();
let nnz = triplets.len();
let mut sum_sq_diff = 0.0;
for (_, _, val) in &triplets {
let diff = val - mean_val;
sum_sq_diff += diff * diff;
}
let num_zeros = total_elements - nnz;
sum_sq_diff += num_zeros as f32 * mean_val * mean_val;
let variance = sum_sq_diff / (total_elements - ddof) as f32;
Ok(variance.sqrt())
}
pub fn std_axis(tensor: &dyn SparseTensor, axis: usize, ddof: usize) -> TorshResult<Vec<f32>> {
utils::validate_axis(axis)?;
let variances = var_axis(tensor, axis)?;
let adjustment_factor = if ddof > 0 {
let n = if axis == 0 {
tensor.shape().dims()[0]
} else {
tensor.shape().dims()[1]
};
(n as f32) / ((n - ddof) as f32)
} else {
1.0
};
Ok(variances
.iter()
.map(|v| (v * adjustment_factor).sqrt())
.collect())
}
/// Fused multiply-add: `beta * C + alpha * (A @ B)`, all sparse, COO result.
pub fn addmm(
    c: &dyn SparseTensor,
    a: &dyn SparseTensor,
    b: &dyn SparseTensor,
    alpha: f32,
    beta: f32,
) -> TorshResult<CooTensor> {
    let product = sparse_matmul_optimized(a, b)?;
    let lhs = scale(&product as &dyn SparseTensor, alpha)?;
    let rhs = scale(c, beta)?;
    element_add(&lhs as &dyn SparseTensor, &rhs as &dyn SparseTensor)
}
/// Numerically stable softmax over the STORED entries of each row (axis 1)
/// or each column (axis 0); implicit zeros do not participate.
///
/// Refactored: the two axis branches were byte-for-byte symmetric, differing
/// only in which coordinate keyed the max/sum maps; a key closure removes the
/// duplication while preserving the exact computation (max-shift, exp, then
/// per-group normalization).
pub fn sparse_softmax(tensor: &dyn SparseTensor, axis: usize) -> TorshResult<CooTensor> {
    utils::validate_axis(axis)?;
    let coo = utils::to_coo_safe(tensor)?;
    let triplets = coo.triplets();
    // axis == 1 groups by row; axis == 0 groups by column.
    let key = |row: usize, col: usize| if axis == 1 { row } else { col };
    // Per-group maximum for the stability shift.
    let mut max_map: HashMap<usize, f32> = HashMap::new();
    for &(row, col, val) in &triplets {
        max_map
            .entry(key(row, col))
            .and_modify(|m| *m = m.max(val))
            .or_insert(val);
    }
    // Shifted exponentials and their per-group sums.
    let mut sum_exp: HashMap<usize, f32> = HashMap::new();
    let mut exp_triplets = Vec::with_capacity(triplets.len());
    for &(row, col, val) in &triplets {
        let max_val = max_map.get(&key(row, col)).copied().unwrap_or(0.0);
        let exp_val = (val - max_val).exp();
        exp_triplets.push((row, col, exp_val));
        *sum_exp.entry(key(row, col)).or_insert(0.0) += exp_val;
    }
    // Normalize each exponential by its group's sum.
    let final_triplets: Vec<(usize, usize, f32)> = exp_triplets
        .into_iter()
        .map(|(row, col, exp_val)| {
            let total = sum_exp.get(&key(row, col)).copied().unwrap_or(1.0);
            (row, col, exp_val / total)
        })
        .collect();
    CooTensor::from_triplets(
        final_triplets,
        (tensor.shape().dims()[0], tensor.shape().dims()[1]),
    )
}
/// Numerically stable log-softmax over the STORED entries of each row
/// (axis 1) or each column (axis 0); implicit zeros do not participate.
/// Each entry becomes `val - logsumexp(group)`.
///
/// Refactored: the two axis branches were fully symmetric (only the grouping
/// coordinate differed), and the intermediate tuples carried unused fields;
/// a key closure removes both while preserving the max-shift / log-sum-exp
/// computation exactly.
pub fn sparse_log_softmax(tensor: &dyn SparseTensor, axis: usize) -> TorshResult<CooTensor> {
    utils::validate_axis(axis)?;
    let coo = utils::to_coo_safe(tensor)?;
    let triplets = coo.triplets();
    // axis == 1 groups by row; axis == 0 groups by column.
    let key = |row: usize, col: usize| if axis == 1 { row } else { col };
    // Per-group maximum for the stability shift.
    let mut max_map: HashMap<usize, f32> = HashMap::new();
    for &(row, col, val) in &triplets {
        max_map
            .entry(key(row, col))
            .and_modify(|m| *m = m.max(val))
            .or_insert(val);
    }
    // Sum of shifted exponentials per group.
    let mut sum_exp: HashMap<usize, f32> = HashMap::new();
    for &(row, col, val) in &triplets {
        let max_val = max_map.get(&key(row, col)).copied().unwrap_or(0.0);
        *sum_exp.entry(key(row, col)).or_insert(0.0) += (val - max_val).exp();
    }
    // logsumexp(group) = group_max + ln(sum of shifted exps).
    let lse: HashMap<usize, f32> = sum_exp
        .into_iter()
        .map(|(k, s)| {
            let max_val = max_map.get(&k).copied().unwrap_or(0.0);
            (k, max_val + s.ln())
        })
        .collect();
    let final_triplets: Vec<(usize, usize, f32)> = triplets
        .into_iter()
        .map(|(row, col, val)| {
            let group_lse = lse.get(&key(row, col)).copied().unwrap_or(0.0);
            (row, col, val - group_lse)
        })
        .collect();
    CooTensor::from_triplets(
        final_triplets,
        (tensor.shape().dims()[0], tensor.shape().dims()[1]),
    )
}
/// Sparse × sparse matrix multiply; delegates to the merge-join kernel.
pub fn sparse_matmul(a: &dyn SparseTensor, b: &dyn SparseTensor) -> TorshResult<CooTensor> {
    sparse_matmul_optimized(a, b)
}
/// Sparse × sparse matrix multiply via CSR rows of `a` merge-joined against
/// CSC columns of `b`. Dot products with magnitude <= f32::EPSILON are
/// dropped from the COO result.
pub fn sparse_matmul_optimized(
    a: &dyn SparseTensor,
    b: &dyn SparseTensor,
) -> TorshResult<CooTensor> {
    if a.shape().dims()[1] != b.shape().dims()[0] {
        return Err(TorshError::InvalidArgument(format!(
            "Dimension mismatch: [{} x {}] @ [{} x {}]",
            a.shape().dims()[0],
            a.shape().dims()[1],
            b.shape().dims()[0],
            b.shape().dims()[1]
        )));
    }
    let a_csr = a.to_csr()?;
    let b_csc = b.to_csc()?;
    let m = a.shape().dims()[0];
    let n = b.shape().dims()[1];
    let mut products: HashMap<(usize, usize), f32> = HashMap::new();
    for row in 0..m {
        let (a_cols, a_vals) = a_csr.get_row(row)?;
        for col in 0..n {
            let (b_rows, b_vals) = b_csc.get_col(col)?;
            // Two-pointer merge over the sorted index lists.
            let mut acc = 0.0;
            let mut ia = 0;
            let mut ib = 0;
            while ia < a_cols.len() && ib < b_rows.len() {
                match a_cols[ia].cmp(&b_rows[ib]) {
                    std::cmp::Ordering::Equal => {
                        acc += a_vals[ia] * b_vals[ib];
                        ia += 1;
                        ib += 1;
                    }
                    std::cmp::Ordering::Less => ia += 1,
                    std::cmp::Ordering::Greater => ib += 1,
                }
            }
            if acc.abs() > f32::EPSILON {
                products.insert((row, col), acc);
            }
        }
    }
    let mut row_indices = Vec::with_capacity(products.len());
    let mut col_indices = Vec::with_capacity(products.len());
    let mut values = Vec::with_capacity(products.len());
    for ((r, c), v) in products {
        row_indices.push(r);
        col_indices.push(c);
        values.push(v);
    }
    let result_shape = crate::Shape::new(vec![m, n]);
    CooTensor::new(row_indices, col_indices, values, result_shape)
}
/// Blocked sparse × sparse matrix multiply: computes `a @ b` in
/// `block_size`-sized tiles over (i, j, k) for cache locality, accumulating
/// partial dot products per output position across k-blocks.
///
/// `block_size` defaults to 64 when `None`.
///
/// NOTE(review): the per-block two-pointer merge assumes `get_row`/`get_col`
/// return indices in ascending order — confirm with the CSR/CSC impl.
pub fn sparse_matmul_blocked(
    a: &dyn SparseTensor,
    b: &dyn SparseTensor,
    block_size: Option<usize>,
) -> TorshResult<CooTensor> {
    let block_size = block_size.unwrap_or(64);
    if a.shape().dims()[1] != b.shape().dims()[0] {
        return Err(TorshError::InvalidArgument(format!(
            "Dimension mismatch: [{} x {}] @ [{} x {}]",
            a.shape().dims()[0],
            a.shape().dims()[1],
            b.shape().dims()[0],
            b.shape().dims()[1]
        )));
    }
    let m = a.shape().dims()[0];
    let n = b.shape().dims()[1];
    let k = a.shape().dims()[1];
    // Row-major access for A, column-major for B, so each (i, j) pair can be
    // merge-joined over the shared k dimension.
    let a_csr = a.to_csr()?;
    let b_csc = b.to_csc()?;
    // Partial sums accumulate across k-blocks, keyed by output position.
    let mut result_map: HashMap<(usize, usize), f32> = HashMap::new();
    for i_block in (0..m).step_by(block_size) {
        for j_block in (0..n).step_by(block_size) {
            for k_block in (0..k).step_by(block_size) {
                // Clamp tile edges at the matrix boundaries.
                let i_end = std::cmp::min(i_block + block_size, m);
                let j_end = std::cmp::min(j_block + block_size, n);
                let k_end = std::cmp::min(k_block + block_size, k);
                for i in i_block..i_end {
                    let (a_cols, a_vals) = a_csr.get_row(i)?;
                    for j in j_block..j_end {
                        let (b_rows, b_vals) = b_csc.get_col(j)?;
                        let mut partial_sum = 0.0;
                        // Two-pointer merge restricted to indices within
                        // [k_block, k_end).
                        let (mut a_idx, mut b_idx) = (0, 0);
                        while a_idx < a_cols.len() && b_idx < b_rows.len() {
                            let a_col = a_cols[a_idx];
                            let b_row = b_rows[b_idx];
                            // Advance the smaller pointer past indices below
                            // the current k-block.
                            if a_col < k_block || b_row < k_block {
                                if a_col < b_row {
                                    a_idx += 1;
                                } else {
                                    b_idx += 1;
                                }
                                continue;
                            }
                            // Once either sorted list passes k_end, no more
                            // matches can occur inside this k-block.
                            if a_col >= k_end || b_row >= k_end {
                                break;
                            }
                            if a_col == b_row {
                                partial_sum += a_vals[a_idx] * b_vals[b_idx];
                                a_idx += 1;
                                b_idx += 1;
                            } else if a_col < b_row {
                                a_idx += 1;
                            } else {
                                b_idx += 1;
                            }
                        }
                        if partial_sum.abs() > f32::EPSILON {
                            // Accumulate: the same (i, j) cell receives
                            // contributions from several k-blocks.
                            *result_map.entry((i, j)).or_insert(0.0) += partial_sum;
                        }
                    }
                }
            }
        }
    }
    // Flatten into COO-parallel vectors; re-filter because partial sums may
    // have cancelled across k-blocks.
    let (row_indices, col_indices, values): (Vec<_>, Vec<_>, Vec<_>) = result_map
        .into_iter()
        .filter(|(_, v)| v.abs() > f32::EPSILON)
        .fold(
            (Vec::new(), Vec::new(), Vec::new()),
            |(mut rows, mut cols, mut vals), ((r, c), v)| {
                rows.push(r);
                cols.push(c);
                vals.push(v);
                (rows, cols, vals)
            },
        );
    let result_shape = crate::Shape::new(vec![m, n]);
    CooTensor::new(row_indices, col_indices, values, result_shape)
}
/// Solve a sparse triangular system `A x = b` (or `A^T x = b` when
/// `transpose` is true) by forward/back substitution.
///
/// * `a` — square sparse triangular matrix; `upper` selects which triangle
///   the stored entries are taken from.
/// * `b` — dense right-hand side, either a vector `[n]` or a matrix
///   `[n, nrhs]` (each column solved independently).
///
/// Errors with `InvalidArgument` on a size mismatch and `ComputeError` when a
/// stored diagonal element is (near-)zero.
///
/// NOTE(review): if a diagonal entry is structurally ABSENT (not stored at
/// all), no division occurs and no error is raised for that row — the row is
/// solved as if the diagonal were 1. Confirm the CSR/CSC conversion always
/// stores diagonals, or track a per-row "diagonal found" flag.
/// NOTE(review): the `break` after the diagonal in the (lower, no-transpose)
/// and (upper, transpose) cases assumes indices are returned in ascending
/// order — confirm with the CSR/CSC impl.
pub fn triangular_solve(
    a: &dyn SparseTensor,
    b: &Tensor,
    upper: bool,
    transpose: bool,
) -> TorshResult<Tensor> {
    utils::validate_square(a)?;
    let n = a.shape().dims()[0];
    if b.shape().dims()[0] != n {
        return Err(TorshError::InvalidArgument(format!(
            "Dimension mismatch: matrix size {} but RHS size {}",
            n,
            b.shape().dims()[0]
        )));
    }
    // A 1-D RHS is treated as a single column; indexing differs throughout.
    let is_vector = b.shape().ndim() == 1;
    let nrhs = if is_vector { 1 } else { b.shape().dims()[1] };
    let a_csr = a.to_csr()?;
    let x = if is_vector {
        zeros::<f32>(&[n])?
    } else {
        zeros::<f32>(&[n, nrhs])?
    };
    match (upper, transpose) {
        // Lower triangular, A x = b: forward substitution (i ascending).
        (false, false) => {
            for i in 0..n {
                for j in 0..nrhs {
                    let mut sum = if is_vector {
                        b.get(&[i])?
                    } else {
                        b.get(&[i, j])?
                    };
                    let (cols, vals) = a_csr.get_row(i)?;
                    for (k, &col) in cols.iter().enumerate() {
                        if col < i {
                            // Subtract contributions of already-solved x's.
                            let x_val = if is_vector {
                                x.get(&[col])?
                            } else {
                                x.get(&[col, j])?
                            };
                            sum -= vals[k] * x_val;
                        } else if col == i {
                            if vals[k].abs() < f32::EPSILON {
                                return Err(TorshError::ComputeError(
                                    "Singular matrix: zero diagonal element".to_string(),
                                ));
                            }
                            sum /= vals[k];
                            // Diagonal is the last relevant entry of a sorted
                            // lower-triangular row.
                            break;
                        }
                    }
                    if is_vector {
                        x.set(&[i], sum)?;
                    } else {
                        x.set(&[i, j], sum)?;
                    }
                }
            }
        }
        // Upper triangular, A x = b: back substitution (i descending).
        (true, false) => {
            for i in (0..n).rev() {
                for j in 0..nrhs {
                    let mut sum = if is_vector {
                        b.get(&[i])?
                    } else {
                        b.get(&[i, j])?
                    };
                    let (cols, vals) = a_csr.get_row(i)?;
                    for (k, &col) in cols.iter().enumerate() {
                        if col > i {
                            // Subtract contributions of already-solved x's.
                            let x_val = if is_vector {
                                x.get(&[col])?
                            } else {
                                x.get(&[col, j])?
                            };
                            sum -= vals[k] * x_val;
                        } else if col == i {
                            if vals[k].abs() < f32::EPSILON {
                                return Err(TorshError::ComputeError(
                                    "Singular matrix: zero diagonal element".to_string(),
                                ));
                            }
                            sum /= vals[k];
                        }
                    }
                    if is_vector {
                        x.set(&[i], sum)?;
                    } else {
                        x.set(&[i, j], sum)?;
                    }
                }
            }
        }
        // Lower triangular, A^T x = b: A^T is upper, so back substitution;
        // column i of A supplies row i of A^T.
        (false, true) => {
            let a_csc = a.to_csc()?;
            for i in (0..n).rev() {
                for j in 0..nrhs {
                    let mut sum = if is_vector {
                        b.get(&[i])?
                    } else {
                        b.get(&[i, j])?
                    };
                    let (rows, vals) = a_csc.get_col(i)?;
                    for (k, &row) in rows.iter().enumerate() {
                        if row > i {
                            // Subtract contributions of already-solved x's.
                            let x_val = if is_vector {
                                x.get(&[row])?
                            } else {
                                x.get(&[row, j])?
                            };
                            sum -= vals[k] * x_val;
                        } else if row == i {
                            if vals[k].abs() < f32::EPSILON {
                                return Err(TorshError::ComputeError(
                                    "Singular matrix: zero diagonal element".to_string(),
                                ));
                            }
                            sum /= vals[k];
                        }
                    }
                    if is_vector {
                        x.set(&[i], sum)?;
                    } else {
                        x.set(&[i, j], sum)?;
                    }
                }
            }
        }
        // Upper triangular, A^T x = b: A^T is lower, so forward substitution;
        // column i of A supplies row i of A^T.
        (true, true) => {
            let a_csc = a.to_csc()?;
            for i in 0..n {
                for j in 0..nrhs {
                    let mut sum = if is_vector {
                        b.get(&[i])?
                    } else {
                        b.get(&[i, j])?
                    };
                    let (rows, vals) = a_csc.get_col(i)?;
                    for (k, &row) in rows.iter().enumerate() {
                        if row < i {
                            // Subtract contributions of already-solved x's.
                            let x_val = if is_vector {
                                x.get(&[row])?
                            } else {
                                x.get(&[row, j])?
                            };
                            sum -= vals[k] * x_val;
                        } else if row == i {
                            if vals[k].abs() < f32::EPSILON {
                                return Err(TorshError::ComputeError(
                                    "Singular matrix: zero diagonal element".to_string(),
                                ));
                            }
                            sum /= vals[k];
                            // Diagonal is the last relevant entry of a sorted
                            // column here.
                            break;
                        }
                    }
                    if is_vector {
                        x.set(&[i], sum)?;
                    } else {
                        x.set(&[i, j], sum)?;
                    }
                }
            }
        }
    }
    Ok(x)
}
pub fn addcmul(
input: &dyn SparseTensor,
tensor1: &dyn SparseTensor,
tensor2: &dyn SparseTensor,
value: f32,
) -> TorshResult<CooTensor> {
utils::validate_same_shape(input, tensor1)?;
utils::validate_same_shape(input, tensor2)?;
let input_coo = utils::to_coo_safe(input)?;
let tensor1_coo = utils::to_coo_safe(tensor1)?;
let tensor2_coo = utils::to_coo_safe(tensor2)?;
let input_map = utils::create_position_map(&input_coo);
let tensor1_map = utils::create_position_map(&tensor1_coo);
let tensor2_map = utils::create_position_map(&tensor2_coo);
let mut result_map: HashMap<(usize, usize), f32> = HashMap::new();
for ((row, col), val) in input_map {
result_map.insert((row, col), val);
}
for ((row, col), val1) in tensor1_map {
if let Some(&val2) = tensor2_map.get(&(row, col)) {
let product = value * val1 * val2;
*result_map.entry((row, col)).or_insert(0.0) += product;
}
}
let (row_indices, col_indices, values): (Vec<_>, Vec<_>, Vec<_>) = result_map
.into_iter()
.filter(|(_, v)| v.abs() > f32::EPSILON)
.fold(
(Vec::new(), Vec::new(), Vec::new()),
|(mut rows, mut cols, mut vals), ((r, c), v)| {
rows.push(r);
cols.push(c);
vals.push(v);
(rows, cols, vals)
},
);
CooTensor::new(row_indices, col_indices, values, input.shape().clone())
}
pub fn addcdiv(
input: &dyn SparseTensor,
tensor1: &dyn SparseTensor,
tensor2: &dyn SparseTensor,
value: f32,
) -> TorshResult<CooTensor> {
utils::validate_same_shape(input, tensor1)?;
utils::validate_same_shape(input, tensor2)?;
let input_coo = utils::to_coo_safe(input)?;
let tensor1_coo = utils::to_coo_safe(tensor1)?;
let tensor2_coo = utils::to_coo_safe(tensor2)?;
let input_map = utils::create_position_map(&input_coo);
let tensor1_map = utils::create_position_map(&tensor1_coo);
let tensor2_map = utils::create_position_map(&tensor2_coo);
let mut result_map: HashMap<(usize, usize), f32> = HashMap::new();
for ((row, col), val) in input_map {
result_map.insert((row, col), val);
}
for ((row, col), val1) in tensor1_map {
if let Some(&val2) = tensor2_map.get(&(row, col)) {
if val2.abs() > f32::EPSILON {
let quotient = value * val1 / val2;
*result_map.entry((row, col)).or_insert(0.0) += quotient;
}
}
}
let (row_indices, col_indices, values): (Vec<_>, Vec<_>, Vec<_>) = result_map
.into_iter()
.filter(|(_, v)| v.abs() > f32::EPSILON)
.fold(
(Vec::new(), Vec::new(), Vec::new()),
|(mut rows, mut cols, mut vals), ((r, c), v)| {
rows.push(r);
cols.push(c);
vals.push(v);
(rows, cols, vals)
},
);
CooTensor::new(row_indices, col_indices, values, input.shape().clone())
}
/// Replace every stored value satisfying `condition` with `fill_value`.
/// Results with magnitude <= f32::EPSILON are dropped, so filling with 0.0
/// removes matching entries from the sparse structure.
pub fn masked_fill<F>(
    tensor: &dyn SparseTensor,
    condition: F,
    fill_value: f32,
) -> TorshResult<CooTensor>
where
    F: Fn(f32) -> bool,
{
    let coo = utils::to_coo_safe(tensor)?;
    let mut replaced = Vec::new();
    for (r, c, v) in coo.triplets() {
        let new_val = if condition(v) { fill_value } else { v };
        replaced.push((r, c, new_val));
    }
    let (rows, cols, vals) = utils::extract_filtered_triplets(replaced, f32::EPSILON);
    CooTensor::new(rows, cols, vals, tensor.shape().clone())
}
/// Clamp every stored value into `[min, max]`; either bound may be omitted.
/// Results with magnitude <= f32::EPSILON are dropped, so values clamped to
/// 0.0 leave the sparse structure.
pub fn clamp(
    tensor: &dyn SparseTensor,
    min: Option<f32>,
    max: Option<f32>,
) -> TorshResult<CooTensor> {
    let coo = utils::to_coo_safe(tensor)?;
    let mut bounded = Vec::new();
    for (r, c, v) in coo.triplets() {
        let mut val = v;
        if let Some(lo) = min {
            val = val.max(lo);
        }
        if let Some(hi) = max {
            val = val.min(hi);
        }
        bounded.push((r, c, val));
    }
    let (rows, cols, vals) = utils::extract_filtered_triplets(bounded, f32::EPSILON);
    CooTensor::new(rows, cols, vals, tensor.shape().clone())
}
/// Absolute value of every stored entry (threshold 0.0 drops exact zeros).
pub fn abs(tensor: &dyn SparseTensor) -> TorshResult<CooTensor> {
    let coo = utils::to_coo_safe(tensor)?;
    let mapped: Vec<_> = coo
        .triplets()
        .iter()
        .map(|&(r, c, v)| (r, c, v.abs()))
        .collect();
    let (rows, cols, vals) = utils::extract_filtered_triplets(mapped, 0.0);
    CooTensor::new(rows, cols, vals, tensor.shape().clone())
}
/// Sign of every stored value: 1.0 for positive, -1.0 for negative, 0.0
/// otherwise (including NaN). Zero signs are filtered out of the result.
pub fn sign(tensor: &dyn SparseTensor) -> TorshResult<CooTensor> {
    let coo = utils::to_coo_safe(tensor)?;
    let mut mapped = Vec::new();
    for (r, c, v) in coo.triplets() {
        // partial_cmp handles NaN the same way the comparison chain did:
        // neither Greater nor Less, so it maps to 0.0.
        let s = match v.partial_cmp(&0.0) {
            Some(std::cmp::Ordering::Greater) => 1.0,
            Some(std::cmp::Ordering::Less) => -1.0,
            _ => 0.0,
        };
        mapped.push((r, c, s));
    }
    let (rows, cols, vals) = utils::extract_filtered_triplets(mapped, f32::EPSILON);
    CooTensor::new(rows, cols, vals, tensor.shape().clone())
}
/// Raise every stored value to `exponent` via `powf`.
/// Near-zero results (magnitude <= f32::EPSILON) are dropped.
pub fn pow(tensor: &dyn SparseTensor, exponent: f32) -> TorshResult<CooTensor> {
    let coo = utils::to_coo_safe(tensor)?;
    let triplets = coo.triplets();
    let mut powered = Vec::with_capacity(triplets.len());
    for (r, c, v) in triplets {
        powered.push((r, c, v.powf(exponent)));
    }
    let (rows, cols, vals) = utils::extract_filtered_triplets(powered, f32::EPSILON);
    CooTensor::new(rows, cols, vals, tensor.shape().clone())
}
/// Square every stored value; shorthand for `pow(tensor, 2.0)`.
pub fn square(tensor: &dyn SparseTensor) -> TorshResult<CooTensor> {
    pow(tensor, 2.0)
}
/// Square root of every stored value. Negative inputs produce NaN, which the
/// epsilon filter then silently drops from the result.
pub fn sqrt(tensor: &dyn SparseTensor) -> TorshResult<CooTensor> {
    let coo = utils::to_coo_safe(tensor)?;
    let triplets = coo.triplets();
    let mut roots = Vec::with_capacity(triplets.len());
    for (r, c, v) in triplets {
        roots.push((r, c, v.sqrt()));
    }
    let (rows, cols, vals) = utils::extract_filtered_triplets(roots, f32::EPSILON);
    CooTensor::new(rows, cols, vals, tensor.shape().clone())
}
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::creation::ones;

    // Sparse x dense product has the expected output shape.
    #[test]
    fn test_spmm() {
        let row_indices = vec![0, 1, 2];
        let col_indices = vec![1, 2, 0];
        let values = vec![1.0, 2.0, 3.0];
        let shape = crate::Shape::new(vec![3, 3]);
        let sparse = CooTensor::new(row_indices, col_indices, values, shape).unwrap();
        let dense = ones::<f32>(&[3, 2]).unwrap();
        let result = spmm(&sparse as &dyn SparseTensor, &dense).unwrap();
        assert_eq!(result.shape().dims(), &[3, 2]);
    }

    // a + 0.5*b with disjoint sparsity patterns keeps all four entries.
    #[test]
    fn test_spadd() {
        let a = CooTensor::new(
            vec![0, 1],
            vec![0, 1],
            vec![1.0, 2.0],
            crate::Shape::new(vec![2, 2]),
        )
        .unwrap();
        let b = CooTensor::new(
            vec![0, 1],
            vec![1, 0],
            vec![3.0, 4.0],
            crate::Shape::new(vec![2, 2]),
        )
        .unwrap();
        let result = spadd(&a as &dyn SparseTensor, &b as &dyn SparseTensor, 0.5).unwrap();
        assert_eq!(result.nnz(), 4);
    }

    // Global sum over stored values only.
    #[test]
    fn test_sum() {
        let coo = CooTensor::new(
            vec![0, 1, 2],
            vec![0, 1, 2],
            vec![1.0, 2.0, 3.0],
            crate::Shape::new(vec![3, 3]),
        )
        .unwrap();
        let total = sum(&coo as &dyn SparseTensor).unwrap();
        assert_eq!(total, 6.0);
    }

    // Axis 0 -> per-column sums; axis 1 -> per-row sums.
    #[test]
    fn test_sum_axis() {
        let coo = CooTensor::new(
            vec![0, 0, 1, 1],
            vec![0, 1, 0, 1],
            vec![1.0, 2.0, 3.0, 4.0],
            crate::Shape::new(vec![2, 2]),
        )
        .unwrap();
        let col_sums = sum_axis(&coo as &dyn SparseTensor, 0).unwrap();
        assert_eq!(col_sums, vec![4.0, 6.0]);
        let row_sums = sum_axis(&coo as &dyn SparseTensor, 1).unwrap();
        assert_eq!(row_sums, vec![3.0, 7.0]);
    }

    // Classic 3-4-5 triangle: L2 norm of [3, 4, 0] is 5.
    #[test]
    fn test_norm() {
        let coo = CooTensor::new(
            vec![0, 1, 2],
            vec![0, 1, 2],
            vec![3.0, 4.0, 0.0],
            crate::Shape::new(vec![3, 3]),
        )
        .unwrap();
        let l2_norm = norm(&coo as &dyn SparseTensor).unwrap();
        assert_eq!(l2_norm, 5.0);
    }

    // Scaling by 0.5 halves every stored value, keeping the entry count.
    #[test]
    fn test_scale() {
        let coo = CooTensor::new(
            vec![0, 1],
            vec![0, 1],
            vec![2.0, 4.0],
            crate::Shape::new(vec![2, 2]),
        )
        .unwrap();
        let scaled = scale(&coo as &dyn SparseTensor, 0.5).unwrap();
        let triplets = scaled.triplets();
        assert_eq!(triplets.len(), 2);
        assert_eq!(triplets[0].2, 1.0);
        assert_eq!(triplets[1].2, 2.0);
    }

    // Off-diagonal entries are ignored; only (i, i) values are extracted.
    #[test]
    fn test_diag() {
        let coo = CooTensor::new(
            vec![0, 0, 1, 1, 2],
            vec![0, 1, 1, 2, 2],
            vec![1.0, 2.0, 3.0, 4.0, 5.0],
            crate::Shape::new(vec![3, 3]),
        )
        .unwrap();
        let diagonal = diag(&coo as &dyn SparseTensor).unwrap();
        assert_eq!(diagonal, vec![1.0, 3.0, 5.0]);
    }

    // Forward substitution on a dense lower-triangular 3x3 system.
    #[test]
    fn test_triangular_solve_lower() {
        let lower = CooTensor::new(
            vec![0, 1, 1, 2, 2, 2],
            vec![0, 0, 1, 0, 1, 2],
            vec![2.0, 1.0, 3.0, 2.0, 1.0, 4.0],
            crate::Shape::new(vec![3, 3]),
        )
        .unwrap();
        let b = torsh_tensor::Tensor::from_vec(vec![2.0, 7.0, 15.0], &[3]).unwrap();
        let x = triangular_solve(&lower as &dyn SparseTensor, &b, false, false).unwrap();
        assert!((x.get(&[0]).unwrap() - 1.0).abs() < 1e-5);
        assert!((x.get(&[1]).unwrap() - 2.0).abs() < 1e-5);
        assert!((x.get(&[2]).unwrap() - 2.75).abs() < 1e-5);
    }

    // Back substitution on a dense upper-triangular 3x3 system.
    #[test]
    fn test_triangular_solve_upper() {
        let upper = CooTensor::new(
            vec![0, 0, 0, 1, 1, 2],
            vec![0, 1, 2, 1, 2, 2],
            vec![2.0, 1.0, 2.0, 3.0, 1.0, 4.0],
            crate::Shape::new(vec![3, 3]),
        )
        .unwrap();
        let b = torsh_tensor::Tensor::from_vec(vec![18.0, 15.0, 12.0], &[3]).unwrap();
        let x = triangular_solve(&upper as &dyn SparseTensor, &b, true, false).unwrap();
        assert!((x.get(&[0]).unwrap() - 1.0).abs() < 1e-5);
        assert!((x.get(&[1]).unwrap() - 2.0).abs() < 1e-5);
        assert!((x.get(&[2]).unwrap() - 3.0).abs() < 1e-5);
    }

    // input + 0.5 * t1 * t2 where the product term applies at (0,0), (0,1)
    // and (1,1) carries only input's value.
    #[test]
    fn test_addcmul() {
        let input = CooTensor::new(
            vec![0, 1],
            vec![0, 1],
            vec![1.0, 2.0],
            crate::Shape::new(vec![2, 2]),
        )
        .unwrap();
        let tensor1 = CooTensor::new(
            vec![0, 0],
            vec![0, 1],
            vec![2.0, 3.0],
            crate::Shape::new(vec![2, 2]),
        )
        .unwrap();
        let tensor2 = CooTensor::new(
            vec![0, 0],
            vec![0, 1],
            vec![4.0, 5.0],
            crate::Shape::new(vec![2, 2]),
        )
        .unwrap();
        let result = addcmul(
            &input as &dyn SparseTensor,
            &tensor1 as &dyn SparseTensor,
            &tensor2 as &dyn SparseTensor,
            0.5,
        )
        .unwrap();
        let result_map: std::collections::HashMap<(usize, usize), f32> = result
            .triplets()
            .into_iter()
            .map(|(r, c, v)| ((r, c), v))
            .collect();
        // (0,0): 1 + 0.5*2*4 = 5; (0,1): 0 + 0.5*3*5 = 7.5; (1,1): input only.
        assert!((result_map.get(&(0, 0)).unwrap() - 5.0).abs() < 1e-5);
        assert!((result_map.get(&(0, 1)).unwrap() - 7.5).abs() < 1e-5);
        assert!((result_map.get(&(1, 1)).unwrap() - 2.0).abs() < 1e-5);
    }

    // input + 0.5 * t1 / t2 at positions where both t1 and t2 are stored.
    #[test]
    fn test_addcdiv() {
        let input = CooTensor::new(
            vec![0, 1],
            vec![0, 1],
            vec![1.0, 2.0],
            crate::Shape::new(vec![2, 2]),
        )
        .unwrap();
        let tensor1 = CooTensor::new(
            vec![0, 0],
            vec![0, 1],
            vec![8.0, 10.0],
            crate::Shape::new(vec![2, 2]),
        )
        .unwrap();
        let tensor2 = CooTensor::new(
            vec![0, 0],
            vec![0, 1],
            vec![2.0, 5.0],
            crate::Shape::new(vec![2, 2]),
        )
        .unwrap();
        let result = addcdiv(
            &input as &dyn SparseTensor,
            &tensor1 as &dyn SparseTensor,
            &tensor2 as &dyn SparseTensor,
            0.5,
        )
        .unwrap();
        let result_map: std::collections::HashMap<(usize, usize), f32> = result
            .triplets()
            .into_iter()
            .map(|(r, c, v)| ((r, c), v))
            .collect();
        // (0,0): 1 + 0.5*8/2 = 3; (0,1): 0 + 0.5*10/5 = 1; (1,1): input only.
        assert!((result_map.get(&(0, 0)).unwrap() - 3.0).abs() < 1e-5);
        assert!((result_map.get(&(0, 1)).unwrap() - 1.0).abs() < 1e-5);
        assert!((result_map.get(&(1, 1)).unwrap() - 2.0).abs() < 1e-5);
    }

    // Filling negatives with 0.0 removes them from the sparse structure.
    #[test]
    fn test_masked_fill() {
        let tensor = CooTensor::new(
            vec![0, 0, 1, 1],
            vec![0, 1, 0, 1],
            vec![1.0, -2.0, 3.0, -4.0],
            crate::Shape::new(vec![2, 2]),
        )
        .unwrap();
        let result = masked_fill(&tensor as &dyn SparseTensor, |v| v < 0.0, 0.0).unwrap();
        let result_map: std::collections::HashMap<(usize, usize), f32> = result
            .triplets()
            .into_iter()
            .map(|(r, c, v)| ((r, c), v))
            .collect();
        assert_eq!(result_map.get(&(0, 0)).unwrap(), &1.0);
        assert_eq!(result_map.get(&(1, 0)).unwrap(), &3.0);
        assert_eq!(result_map.len(), 2);
    }

    // Clamping to [0, 4]: -2 clamps to 0 (and is dropped), 5 clamps to 4.
    #[test]
    fn test_clamp() {
        let tensor = CooTensor::new(
            vec![0, 0, 1, 1],
            vec![0, 1, 0, 1],
            vec![-2.0, 1.0, 5.0, 3.0],
            crate::Shape::new(vec![2, 2]),
        )
        .unwrap();
        let result = clamp(&tensor as &dyn SparseTensor, Some(0.0), Some(4.0)).unwrap();
        let result_map: std::collections::HashMap<(usize, usize), f32> = result
            .triplets()
            .into_iter()
            .map(|(r, c, v)| ((r, c), v))
            .collect();
        assert_eq!(result_map.get(&(0, 1)).unwrap(), &1.0);
        assert_eq!(result_map.get(&(1, 0)).unwrap(), &4.0);
        assert_eq!(result_map.get(&(1, 1)).unwrap(), &3.0);
    }

    // Absolute value over stored entries.
    #[test]
    fn test_abs_sparse() {
        let tensor = CooTensor::new(
            vec![0, 0, 1],
            vec![0, 1, 1],
            vec![-2.0, 3.0, -4.0],
            crate::Shape::new(vec![2, 2]),
        )
        .unwrap();
        let result = abs(&tensor as &dyn SparseTensor).unwrap();
        let result_map: std::collections::HashMap<(usize, usize), f32> = result
            .triplets()
            .into_iter()
            .map(|(r, c, v)| ((r, c), v))
            .collect();
        assert_eq!(result_map.get(&(0, 0)).unwrap(), &2.0);
        assert_eq!(result_map.get(&(0, 1)).unwrap(), &3.0);
        assert_eq!(result_map.get(&(1, 1)).unwrap(), &4.0);
    }

    // Sign function: the 0.0-valued entry maps to sign 0 and is filtered out.
    #[test]
    fn test_sign_sparse() {
        let tensor = CooTensor::new(
            vec![0, 0, 1],
            vec![0, 1, 1],
            vec![-2.0, 3.0, 0.0],
            crate::Shape::new(vec![2, 2]),
        )
        .unwrap();
        let result = sign(&tensor as &dyn SparseTensor).unwrap();
        let result_map: std::collections::HashMap<(usize, usize), f32> = result
            .triplets()
            .into_iter()
            .map(|(r, c, v)| ((r, c), v))
            .collect();
        assert_eq!(result_map.get(&(0, 0)).unwrap(), &-1.0);
        assert_eq!(result_map.get(&(0, 1)).unwrap(), &1.0);
        assert_eq!(result_map.len(), 2);
    }
}