use crate::{CooTensor, CsrTensor, CscTensor, SparseTensor, SparseFormat, TorshResult};
use scirs2_core::random::{Random, rng};
use std::collections::HashMap;
use torsh_core::{Shape, TorshError};
use torsh_tensor::{
creation::{randn, zeros},
Tensor,
};
/// Splits `(row, col, value)` triplets into three parallel vectors of row indices,
/// column indices and values, preserving the input order.
fn unzip_triplets(triplets: Vec<(usize, usize, f32)>) -> (Vec<usize>, Vec<usize>, Vec<f32>) {
    let count = triplets.len();
    let mut row_indices = Vec::with_capacity(count);
    let mut col_indices = Vec::with_capacity(count);
    let mut values = Vec::with_capacity(count);
    for (row, col, value) in triplets {
        row_indices.push(row);
        col_indices.push(col);
        values.push(value);
    }
    (row_indices, col_indices, values)
}
/// Element-wise rectified linear unit, `max(0, x)`, for sparse tensors.
///
/// Only explicitly stored entries are visited. Because `relu(0) == 0`, implicit zeros stay
/// implicit. Stored entries that are not strictly positive are dropped from the result, so
/// the output never stores more entries than the input.
///
/// `SparseReLU::default()` is equivalent to `SparseReLU::new(false)`.
#[derive(Debug, Clone, Default)]
pub struct SparseReLU {
    /// Requested in-place flag. `forward` does not read it and always returns a newly
    /// built tensor; the flag is only stored and reported by [`SparseReLU::inplace`].
    inplace: bool,
}

impl SparseReLU {
    /// Creates a ReLU activation with the given in-place flag.
    pub fn new(inplace: bool) -> Self {
        Self { inplace }
    }

    /// Applies ReLU to `input` and returns a new tensor in the same sparse format.
    ///
    /// Stored values that are not `> 0.0` are removed. This includes NaN, because
    /// `NaN > 0.0` is false.
    ///
    /// # Errors
    ///
    /// Propagates errors from converting `input` to COO, from building the result COO
    /// tensor, and from converting the result back to CSR/CSC.
    pub fn forward(&self, input: &dyn SparseTensor) -> TorshResult<Box<dyn SparseTensor>> {
        let coo = input.to_coo()?;
        let triplets = coo.triplets();
        let shape = input.shape().clone();
        let activated_triplets: Vec<(usize, usize, f32)> = triplets
            .into_iter()
            .filter(|&(_, _, val)| val > 0.0)
            .collect();
        let (rows, cols, values) = unzip_triplets(activated_triplets);
        let activated_coo = CooTensor::new(rows, cols, values, shape)?;
        // Return the result in the same storage format the caller passed in.
        match input.format() {
            SparseFormat::Coo => Ok(Box::new(activated_coo)),
            SparseFormat::Csr => Ok(Box::new(CsrTensor::from_coo(&activated_coo)?)),
            SparseFormat::Csc => Ok(Box::new(CscTensor::from_coo(&activated_coo)?)),
        }
    }

    /// Returns the in-place flag this activation was constructed with.
    pub fn inplace(&self) -> bool {
        self.inplace
    }
}
/// Element-wise leaky ReLU for sparse tensors: `x` for `x > 0`, `negative_slope * x`
/// otherwise.
///
/// Only stored entries are transformed. The function maps `0` to `0`, so implicit zeros
/// stay implicit. Results whose magnitude is `<= 1e-10` are not stored in the output.
#[derive(Debug, Clone)]
pub struct SparseLeakyReLU {
    /// Multiplier applied to non-positive inputs.
    negative_slope: f32,
}

impl Default for SparseLeakyReLU {
    /// Leaky ReLU with `negative_slope = 0.01`.
    fn default() -> Self {
        Self::new(0.01)
    }
}

impl SparseLeakyReLU {
    /// Creates a leaky ReLU with the given slope for non-positive inputs.
    pub fn new(negative_slope: f32) -> Self {
        Self { negative_slope }
    }

    /// Applies leaky ReLU to each stored entry of `input` and returns a new tensor in the
    /// same sparse format as `input`.
    ///
    /// # Errors
    ///
    /// Propagates errors from converting `input` to COO, from building the result COO
    /// tensor, and from converting the result back to CSR/CSC.
    pub fn forward(&self, input: &dyn SparseTensor) -> TorshResult<Box<dyn SparseTensor>> {
        let coo = input.to_coo()?;
        let triplets = coo.triplets();
        let shape = input.shape().clone();
        let activated_triplets: Vec<(usize, usize, f32)> = triplets
            .into_iter()
            .map(|(row, col, val)| {
                let leaky_relu_val = if val > 0.0 {
                    val
                } else {
                    self.negative_slope * val
                };
                (row, col, leaky_relu_val)
            })
            // Near-zero results (e.g. every negative input when the slope is 0) are left
            // implicit so the output stays sparse.
            .filter(|&(_, _, val)| val.abs() > 1e-10)
            .collect();
        let (rows, cols, values) = unzip_triplets(activated_triplets);
        let activated_coo = CooTensor::new(rows, cols, values, shape)?;
        // Return the result in the same storage format the caller passed in.
        match input.format() {
            SparseFormat::Coo => Ok(Box::new(activated_coo)),
            SparseFormat::Csr => Ok(Box::new(CsrTensor::from_coo(&activated_coo)?)),
            SparseFormat::Csc => Ok(Box::new(CscTensor::from_coo(&activated_coo)?)),
        }
    }

    /// Returns the slope applied to non-positive inputs.
    pub fn negative_slope(&self) -> f32 {
        self.negative_slope
    }
}
/// Logistic sigmoid, `1 / (1 + e^(-x))`, applied to the stored entries of a sparse tensor.
///
/// No pruning is performed: every stored entry of the input appears in the output.
///
/// NOTE: only explicitly stored entries are transformed. Implicit (unstored) zeros stay
/// zero in the result even though `sigmoid(0) == 0.5`. The output is therefore not the
/// element-wise sigmoid of the equivalent dense tensor.
#[derive(Debug, Clone)]
pub struct SparseSigmoid;

impl Default for SparseSigmoid {
    fn default() -> Self {
        SparseSigmoid::new()
    }
}

impl SparseSigmoid {
    /// Creates the (stateless) sigmoid activation.
    pub fn new() -> Self {
        SparseSigmoid
    }

    /// Applies the sigmoid to each stored entry of `input` and returns a new tensor in the
    /// same sparse format as `input`.
    ///
    /// # Errors
    ///
    /// Propagates errors from converting `input` to COO, from building the result COO
    /// tensor, and from converting the result back to CSR/CSC.
    pub fn forward(&self, input: &dyn SparseTensor) -> TorshResult<Box<dyn SparseTensor>> {
        let source = input.to_coo()?;
        let entries = source.triplets();
        let target_shape = input.shape().clone();

        let mut rows = Vec::with_capacity(entries.len());
        let mut cols = Vec::with_capacity(entries.len());
        let mut values = Vec::with_capacity(entries.len());
        for (r, c, x) in entries {
            rows.push(r);
            cols.push(c);
            values.push(1.0 / (1.0 + (-x).exp()));
        }

        let result = CooTensor::new(rows, cols, values, target_shape)?;
        let boxed: Box<dyn SparseTensor> = match input.format() {
            SparseFormat::Coo => Box::new(result),
            SparseFormat::Csr => Box::new(CsrTensor::from_coo(&result)?),
            SparseFormat::Csc => Box::new(CscTensor::from_coo(&result)?),
        };
        Ok(boxed)
    }
}
/// Hyperbolic tangent applied to the stored entries of a sparse tensor.
///
/// `tanh(0) == 0`, so implicit zeros correctly stay implicit. No pruning is performed:
/// every stored entry of the input appears in the output.
#[derive(Debug, Clone)]
pub struct SparseTanh;

impl Default for SparseTanh {
    fn default() -> Self {
        SparseTanh::new()
    }
}

impl SparseTanh {
    /// Creates the (stateless) tanh activation.
    pub fn new() -> Self {
        SparseTanh
    }

    /// Applies `tanh` to each stored entry of `input` and returns a new tensor in the same
    /// sparse format as `input`.
    ///
    /// # Errors
    ///
    /// Propagates errors from converting `input` to COO, from building the result COO
    /// tensor, and from converting the result back to CSR/CSC.
    pub fn forward(&self, input: &dyn SparseTensor) -> TorshResult<Box<dyn SparseTensor>> {
        let squashed = input
            .to_coo()?
            .triplets()
            .into_iter()
            .map(|(r, c, x)| (r, c, x.tanh()))
            .collect::<Vec<(usize, usize, f32)>>();
        let target_shape = input.shape().clone();
        let (row_idx, col_idx, vals) = unzip_triplets(squashed);
        let result = CooTensor::new(row_idx, col_idx, vals, target_shape)?;

        let boxed: Box<dyn SparseTensor> = match input.format() {
            SparseFormat::Coo => Box::new(result),
            SparseFormat::Csr => Box::new(CsrTensor::from_coo(&result)?),
            SparseFormat::Csc => Box::new(CscTensor::from_coo(&result)?),
        };
        Ok(boxed)
    }
}
/// Gaussian error linear unit for sparse tensors, using the sigmoid approximation
/// `gelu(x) ≈ x * sigmoid(1.702 * x)`.
///
/// The approximation maps `0` to `0`, so only stored entries are transformed and implicit
/// zeros stay implicit. Results whose magnitude is `<= 1e-10` are not stored.
#[derive(Debug, Clone)]
pub struct SparseGELU;

impl Default for SparseGELU {
    fn default() -> Self {
        SparseGELU::new()
    }
}

impl SparseGELU {
    /// Scale `k` inside the sigmoid of the `x * sigmoid(k * x)` GELU approximation.
    const SIGMOID_SCALE: f32 = 1.702;

    /// Creates the (stateless) GELU activation.
    pub fn new() -> Self {
        SparseGELU
    }

    /// Applies approximate GELU to each stored entry of `input` and returns a new tensor
    /// in the same sparse format as `input`.
    ///
    /// # Errors
    ///
    /// Propagates errors from converting `input` to COO, from building the result COO
    /// tensor, and from converting the result back to CSR/CSC.
    pub fn forward(&self, input: &dyn SparseTensor) -> TorshResult<Box<dyn SparseTensor>> {
        let source = input.to_coo()?;
        let entries = source.triplets();
        let target_shape = input.shape().clone();

        let mut rows = Vec::with_capacity(entries.len());
        let mut cols = Vec::with_capacity(entries.len());
        let mut values = Vec::with_capacity(entries.len());
        for (r, c, x) in entries {
            let gate = 1.0 / (1.0 + (-Self::SIGMOID_SCALE * x).exp());
            let y = x * gate;
            // Near-zero outputs are left implicit to keep the result sparse.
            if y.abs() > 1e-10 {
                rows.push(r);
                cols.push(c);
                values.push(y);
            }
        }

        let result = CooTensor::new(rows, cols, values, target_shape)?;
        let boxed: Box<dyn SparseTensor> = match input.format() {
            SparseFormat::Coo => Box::new(result),
            SparseFormat::Csr => Box::new(CsrTensor::from_coo(&result)?),
            SparseFormat::Csc => Box::new(CscTensor::from_coo(&result)?),
        };
        Ok(boxed)
    }
}
/// Swish (SiLU) activation, `x * sigmoid(x)`, for sparse tensors.
///
/// The function maps `0` to `0`, so only stored entries are transformed and implicit zeros
/// stay implicit. Results whose magnitude is `<= 1e-10` are not stored.
#[derive(Debug, Clone)]
pub struct SparseSwish;

impl Default for SparseSwish {
    fn default() -> Self {
        SparseSwish::new()
    }
}

impl SparseSwish {
    /// Creates the (stateless) Swish activation.
    pub fn new() -> Self {
        SparseSwish
    }

    /// Applies Swish to each stored entry of `input` and returns a new tensor in the same
    /// sparse format as `input`.
    ///
    /// # Errors
    ///
    /// Propagates errors from converting `input` to COO, from building the result COO
    /// tensor, and from converting the result back to CSR/CSC.
    pub fn forward(&self, input: &dyn SparseTensor) -> TorshResult<Box<dyn SparseTensor>> {
        let source = input.to_coo()?;
        let kept: Vec<(usize, usize, f32)> = source
            .triplets()
            .into_iter()
            .filter_map(|(r, c, x)| {
                let y = x * (1.0 / (1.0 + (-x).exp()));
                // Near-zero outputs are left implicit to keep the result sparse.
                if y.abs() > 1e-10 {
                    Some((r, c, y))
                } else {
                    None
                }
            })
            .collect();
        let target_shape = input.shape().clone();
        let (row_idx, col_idx, vals) = unzip_triplets(kept);
        let result = CooTensor::new(row_idx, col_idx, vals, target_shape)?;

        let boxed: Box<dyn SparseTensor> = match input.format() {
            SparseFormat::Coo => Box::new(result),
            SparseFormat::Csr => Box::new(CsrTensor::from_coo(&result)?),
            SparseFormat::Csc => Box::new(CscTensor::from_coo(&result)?),
        };
        Ok(boxed)
    }
}
/// Element-wise exponential linear unit for sparse tensors: `x` for `x > 0`,
/// `alpha * (e^x - 1)` otherwise.
///
/// `elu(0) == 0`, so implicit zeros stay implicit and only stored entries are transformed.
/// Results whose magnitude is `<= 1e-10` are not stored in the output.
#[derive(Debug, Clone)]
pub struct SparseELU {
    /// Scale of the branch applied to non-positive inputs.
    alpha: f32,
}

impl Default for SparseELU {
    /// ELU with `alpha = 1.0`.
    fn default() -> Self {
        Self::new(1.0)
    }
}

impl SparseELU {
    /// Creates an ELU activation with the given `alpha`.
    pub fn new(alpha: f32) -> Self {
        Self { alpha }
    }

    /// Applies ELU to each stored entry of `input` and returns a new tensor in the same
    /// sparse format as `input`.
    ///
    /// # Errors
    ///
    /// Propagates errors from converting `input` to COO, from building the result COO
    /// tensor, and from converting the result back to CSR/CSC.
    pub fn forward(&self, input: &dyn SparseTensor) -> TorshResult<Box<dyn SparseTensor>> {
        let coo = input.to_coo()?;
        let triplets = coo.triplets();
        let shape = input.shape().clone();
        let activated_triplets: Vec<(usize, usize, f32)> = triplets
            .into_iter()
            .map(|(row, col, val)| {
                let elu_val = if val > 0.0 {
                    val
                } else {
                    // `exp_m1` computes e^x - 1 accurately for small |x|. `exp(x) - 1.0`
                    // cancels catastrophically there: it rounds to exactly 0 in f32 for
                    // |x| below ~6e-8, and the entry would then be pruned.
                    self.alpha * val.exp_m1()
                };
                (row, col, elu_val)
            })
            // Near-zero results are left implicit so the output stays sparse.
            .filter(|&(_, _, val)| val.abs() > 1e-10)
            .collect();
        let (rows, cols, values) = unzip_triplets(activated_triplets);
        let activated_coo = CooTensor::new(rows, cols, values, shape)?;
        // Return the result in the same storage format the caller passed in.
        match input.format() {
            SparseFormat::Coo => Ok(Box::new(activated_coo)),
            SparseFormat::Csr => Ok(Box::new(CsrTensor::from_coo(&activated_coo)?)),
            SparseFormat::Csc => Ok(Box::new(CscTensor::from_coo(&activated_coo)?)),
        }
    }

    /// Returns the scale applied to the non-positive branch.
    pub fn alpha(&self) -> f32 {
        self.alpha
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse_tensor::SparseFormat;
    use std::collections::HashMap;

    /// Stored entries of the shared 3x3 fixture as `(row, col, value)`. Signs are mixed so
    /// that both branches of the piecewise activations are exercised.
    const FIXTURE: [(usize, usize, f32); 5] = [
        (0, 0, 2.0),
        (0, 1, -1.0),
        (1, 0, 3.0),
        (1, 2, -2.0),
        (2, 1, 1.0),
    ];

    /// Absolute tolerance when comparing `f32` activations against `f64` references.
    const TOLERANCE: f32 = 1e-5;

    /// Builds the 3x3 COO fixture described by [`FIXTURE`].
    fn create_test_coo() -> CooTensor {
        let row_indices: Vec<usize> = FIXTURE.iter().map(|&(row, _, _)| row).collect();
        let col_indices: Vec<usize> = FIXTURE.iter().map(|&(_, col, _)| col).collect();
        let values: Vec<f32> = FIXTURE.iter().map(|&(_, _, val)| val).collect();
        let shape = Shape::new(vec![3, 3]);
        CooTensor::new(row_indices, col_indices, values, shape).expect("Coo Tensor should succeed")
    }

    /// Collects the stored entries of `tensor` into a map keyed by `(row, col)`, so that
    /// assertions do not depend on storage order.
    fn stored_entries(tensor: &dyn SparseTensor) -> HashMap<(usize, usize), f32> {
        tensor
            .to_coo()
            .expect("COO format conversion should succeed")
            .triplets()
            .into_iter()
            .map(|(row, col, val)| ((row, col), val))
            .collect()
    }

    /// Applies the reference function `f` (in `f64`) to every fixture entry.
    fn reference(f: impl Fn(f64) -> f64) -> Vec<((usize, usize), f32)> {
        FIXTURE
            .iter()
            .map(|&(row, col, x)| ((row, col), f(f64::from(x)) as f32))
            .collect()
    }

    /// Asserts that `actual` stores exactly the positions in `expected`, each within
    /// [`TOLERANCE`] of the expected value.
    fn assert_entries_close(
        actual: &HashMap<(usize, usize), f32>,
        expected: &[((usize, usize), f32)],
    ) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "unexpected stored entries: {:?}",
            actual
        );
        for &(position, want) in expected {
            let got = actual
                .get(&position)
                .copied()
                .unwrap_or_else(|| panic!("no stored entry at {:?}", position));
            assert!(
                (got - want).abs() < TOLERANCE,
                "entry {:?}: got {}, want {}",
                position,
                got,
                want
            );
        }
    }

    #[test]
    fn test_sparse_relu() {
        let relu = SparseReLU::new(false);
        let input = create_test_coo();
        let output = relu.forward(&input).expect("forward pass should succeed");
        // Non-positive entries are removed; positive ones pass through unchanged.
        assert_entries_close(
            &stored_entries(&*output),
            &[((0, 0), 2.0), ((1, 0), 3.0), ((2, 1), 1.0)],
        );
    }

    #[test]
    fn test_sparse_leaky_relu() {
        let leaky_relu = SparseLeakyReLU::new(0.1);
        let input = create_test_coo();
        let output = leaky_relu.forward(&input).expect("forward pass should succeed");
        assert_entries_close(
            &stored_entries(&*output),
            &reference(|x| if x > 0.0 { x } else { 0.1 * x }),
        );
        assert_eq!(leaky_relu.negative_slope(), 0.1);
    }

    #[test]
    fn test_sparse_sigmoid() {
        let sigmoid = SparseSigmoid::new();
        let input = create_test_coo();
        let output = sigmoid.forward(&input).expect("forward pass should succeed");
        let entries = stored_entries(&*output);
        assert_entries_close(&entries, &reference(|x| 1.0 / (1.0 + (-x).exp())));
        assert!(entries.values().all(|&val| val > 0.0 && val < 1.0));
    }

    #[test]
    fn test_sparse_tanh() {
        let tanh = SparseTanh::new();
        let input = create_test_coo();
        let output = tanh.forward(&input).expect("forward pass should succeed");
        let entries = stored_entries(&*output);
        assert_entries_close(&entries, &reference(f64::tanh));
        assert!(entries.values().all(|&val| val > -1.0 && val < 1.0));
    }

    #[test]
    fn test_sparse_gelu() {
        let gelu = SparseGELU::new();
        let input = create_test_coo();
        let output = gelu.forward(&input).expect("forward pass should succeed");
        // Sigmoid approximation: x * sigmoid(1.702 * x) == x / (1 + e^(-1.702 x)).
        assert_entries_close(
            &stored_entries(&*output),
            &reference(|x| x / (1.0 + (-1.702 * x).exp())),
        );
    }

    #[test]
    fn test_sparse_swish() {
        let swish = SparseSwish::new();
        let input = create_test_coo();
        let output = swish.forward(&input).expect("forward pass should succeed");
        assert_entries_close(
            &stored_entries(&*output),
            &reference(|x| x / (1.0 + (-x).exp())),
        );
    }

    #[test]
    fn test_sparse_elu() {
        let elu = SparseELU::new(1.0);
        let input = create_test_coo();
        let output = elu.forward(&input).expect("forward pass should succeed");
        assert_entries_close(
            &stored_entries(&*output),
            &reference(|x| if x > 0.0 { x } else { x.exp_m1() }),
        );
        assert_eq!(elu.alpha(), 1.0);
    }

    #[test]
    fn test_sparse_elu_alpha_scales_negative_branch() {
        let elu = SparseELU::new(2.0);
        let input = create_test_coo();
        let output = elu.forward(&input).expect("forward pass should succeed");
        assert_entries_close(
            &stored_entries(&*output),
            &reference(|x| if x > 0.0 { x } else { 2.0 * x.exp_m1() }),
        );
    }

    #[test]
    fn test_activation_defaults() {
        assert!(!SparseReLU::default().inplace());
        assert_eq!(SparseLeakyReLU::default().negative_slope(), 0.01);
        let _sigmoid = SparseSigmoid::default();
        let _tanh = SparseTanh::default();
        let _gelu = SparseGELU::default();
        let _swish = SparseSwish::default();
        assert_eq!(SparseELU::default().alpha(), 1.0);
    }

    #[test]
    fn test_format_preservation() {
        let relu = SparseReLU::new(false);
        let coo_input = create_test_coo();
        let csr_input = CsrTensor::from_coo(&coo_input).expect("Csr Tensor should succeed");
        let csc_input = CscTensor::from_coo(&coo_input).expect("Csc Tensor should succeed");
        let coo_output = relu.forward(&coo_input).expect("forward pass should succeed");
        let csr_output = relu.forward(&csr_input).expect("forward pass should succeed");
        let csc_output = relu.forward(&csc_input).expect("forward pass should succeed");
        assert_eq!(coo_output.format(), SparseFormat::Coo);
        assert_eq!(csr_output.format(), SparseFormat::Csr);
        assert_eq!(csc_output.format(), SparseFormat::Csc);
        // The storage format must not affect the computed values.
        let expected = stored_entries(&*coo_output);
        assert_eq!(stored_entries(&*csr_output), expected);
        assert_eq!(stored_entries(&*csc_output), expected);
    }

    #[test]
    fn test_sparsity_increase_with_relu() {
        let relu = SparseReLU::new(false);
        let input = create_test_coo();
        let input_nnz = input.nnz();
        let output = relu.forward(&input).expect("forward pass should succeed");
        let output_nnz = output.nnz();
        assert!(output_nnz <= input_nnz);
    }

    #[test]
    fn test_sigmoid_keeps_sparsity_pattern() {
        let input = create_test_coo();
        let output = SparseSigmoid::new()
            .forward(&input)
            .expect("forward pass should succeed");
        assert_eq!(output.nnz(), input.nnz());
    }
}