use crate::{FloatElement, Tensor, TensorElement};
use torsh_core::error::{Result, TorshError};
impl<T: FloatElement> Tensor<T> {
    /// Element-wise ReLU: `max(x, 0)`.
    ///
    /// Non-positive elements (and NaN, for which `x > 0` is false) become
    /// zero; all other elements are passed through unchanged.
    ///
    /// # Errors
    /// Propagates any error from reading the tensor data or building the
    /// result tensor.
    pub fn relu(&self) -> Result<Self> {
        let data = self.data()?;
        let zero = <T as TensorElement>::zero();
        let result_data: Vec<T> = data
            .iter()
            .map(|&x| if x > zero { x } else { zero })
            .collect();
        Self::from_data(result_data, self.shape().dims().to_vec(), self.device)
    }

    /// In-place ReLU.
    ///
    /// Clamps every element that is not strictly greater than zero to zero,
    /// matching the out-of-place [`Self::relu`]. (The previous `x < 0` test
    /// left NaN untouched while `relu` mapped NaN to zero; the two now agree.
    /// The redundant `T: PartialOrd` bound is dropped — `relu` already
    /// compares with only the `FloatElement` bound.)
    pub fn relu_(&mut self) -> Result<()> {
        let zero = <T as TensorElement>::zero();
        self.data_mut_apply(|item| {
            // `!(x > 0)` instead of `x < 0` so NaN is clamped too,
            // keeping relu_ consistent with relu.
            if !(*item > zero) {
                *item = zero;
            }
        })?;
        Ok(())
    }

    /// Leaky ReLU: `x` for positive elements, `negative_slope * x` otherwise.
    ///
    /// # Errors
    /// Returns `InvalidArgument` if `negative_slope` cannot be represented
    /// in `T`. (Previously a failed conversion silently fell back to a slope
    /// of zero, turning this into a plain ReLU.)
    pub fn leaky_relu(&self, negative_slope: f64) -> Result<Self> {
        let data = self.data()?;
        let slope = T::from_f64(negative_slope).ok_or_else(|| {
            TorshError::InvalidArgument(format!(
                "negative_slope {} is not representable in the element type",
                negative_slope
            ))
        })?;
        let zero = <T as TensorElement>::zero();
        let result_data: Vec<T> = data
            .iter()
            .map(|&x| if x > zero { x } else { x * slope })
            .collect();
        Self::from_data(result_data, self.shape().dims().to_vec(), self.device)
    }

    /// Logistic sigmoid: `1 / (1 + e^(-x))`, element-wise.
    pub fn sigmoid(&self) -> Result<Self> {
        let data = self.data()?;
        // Hoisted out of the per-element closure.
        let one = <T as TensorElement>::one();
        let result_data: Vec<T> = data
            .iter()
            .map(|&x| one / (one + (-x).exp()))
            .collect();
        Self::from_data(result_data, self.shape().dims().to_vec(), self.device)
    }

    /// In-place logistic sigmoid.
    pub fn sigmoid_(&mut self) -> Result<()> {
        let one = <T as TensorElement>::one();
        self.data_mut_apply(|item| {
            *item = one / (one + (-*item).exp());
        })?;
        Ok(())
    }

    /// Hyperbolic tangent, element-wise.
    pub fn tanh(&self) -> Result<Self> {
        let data = self.data()?;
        let result_data: Vec<T> = data.iter().map(|&x| x.tanh()).collect();
        Self::from_data(result_data, self.shape().dims().to_vec(), self.device)
    }

    /// In-place hyperbolic tangent.
    pub fn tanh_(&mut self) -> Result<()> {
        self.data_mut_apply(|item| {
            *item = item.tanh();
        })?;
        Ok(())
    }

    /// GELU activation using the tanh approximation:
    /// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
    ///
    /// The constants now use `expect` (matching `softmax`) instead of a
    /// silent fall-back to zero, which would have produced wrong results
    /// without any signal if the conversion ever failed.
    pub fn gelu(&self) -> Result<Self> {
        let data = self.data()?;
        let half = T::from_f64(0.5).expect("f64 conversion should succeed");
        let one = <T as TensorElement>::one();
        // sqrt(2 / pi), the scale factor of the tanh approximation.
        let c1 = T::from_f64(0.7978845608).expect("f64 conversion should succeed");
        // Cubic-term coefficient from the GELU tanh approximation.
        let c2 = T::from_f64(0.044715).expect("f64 conversion should succeed");
        let result_data: Vec<T> = data
            .iter()
            .map(|&x| {
                let x3 = x * x * x;
                let inner = c1 * (x + c2 * x3);
                half * x * (one + inner.tanh())
            })
            .collect();
        Self::from_data(result_data, self.shape().dims().to_vec(), self.device)
    }

    /// Resolve a possibly-negative `dim` against `shape` and validate it.
    ///
    /// Shared by [`Self::softmax`] and [`Self::log_softmax`] (the logic was
    /// previously duplicated verbatim); `op` names the operation in the
    /// empty-tensor error message, keeping both messages byte-identical to
    /// the originals.
    fn resolve_softmax_dim(dim: i32, shape: &[usize], op: &str) -> Result<usize> {
        if shape.is_empty() {
            return Err(TorshError::InvalidOperation(format!(
                "Cannot compute {} on empty tensor",
                op
            )));
        }
        // A negative dim counts from the back. A dim below -len makes the
        // cast wrap to a huge usize, which the range check below rejects.
        let actual_dim = if dim < 0 {
            (shape.len() as i32 + dim) as usize
        } else {
            dim as usize
        };
        if actual_dim >= shape.len() {
            return Err(TorshError::InvalidArgument(format!(
                "Dimension {} out of range for tensor with {} dimensions",
                dim,
                shape.len()
            )));
        }
        Ok(actual_dim)
    }

    /// Softmax along `dim` (negative indices count from the back).
    ///
    /// Uses the max-subtraction trick so large inputs do not overflow `exp`.
    ///
    /// # Errors
    /// `InvalidOperation` for an empty tensor, `InvalidArgument` for an
    /// out-of-range `dim`.
    pub fn softmax(&self, dim: i32) -> Result<Self> {
        let data = self.data()?;
        let shape_binding = self.shape();
        let shape = shape_binding.dims();
        let actual_dim = Self::resolve_softmax_dim(dim, shape, "softmax")?;
        let dim_size = shape[actual_dim];
        let outer_size: usize = shape[..actual_dim].iter().product();
        let inner_size: usize = shape[actual_dim + 1..].iter().product();
        let zero = <T as TensorElement>::zero();
        let mut result_data = vec![zero; data.len()];
        // Scratch buffer for one softmax slice, allocated once and reused
        // (previously re-allocated on every inner iteration).
        let mut exp_values = vec![zero; dim_size];
        for outer in 0..outer_size {
            for inner in 0..inner_size {
                let base_idx = outer * dim_size * inner_size + inner;
                // Max over the softmax axis, for numerical stability.
                let mut max_val = data[base_idx];
                for d in 1..dim_size {
                    let idx = base_idx + d * inner_size;
                    if data[idx] > max_val {
                        max_val = data[idx];
                    }
                }
                let mut exp_sum = zero;
                for d in 0..dim_size {
                    let idx = base_idx + d * inner_size;
                    let exp_val = (data[idx] - max_val).exp();
                    exp_values[d] = exp_val;
                    exp_sum = exp_sum + exp_val;
                }
                for d in 0..dim_size {
                    let idx = base_idx + d * inner_size;
                    result_data[idx] = exp_values[d] / exp_sum;
                }
            }
        }
        Self::from_data(result_data, shape.to_vec(), self.device)
    }

    /// Log-softmax along `dim` (negative indices count from the back).
    ///
    /// Computed directly as `x - max - ln(sum(exp(x - max)))`, which is more
    /// numerically stable than `softmax(..).ln()`.
    ///
    /// # Errors
    /// `InvalidOperation` for an empty tensor, `InvalidArgument` for an
    /// out-of-range `dim`.
    pub fn log_softmax(&self, dim: i32) -> Result<Self> {
        let data = self.data()?;
        let shape_binding = self.shape();
        let shape = shape_binding.dims();
        let actual_dim = Self::resolve_softmax_dim(dim, shape, "log_softmax")?;
        let dim_size = shape[actual_dim];
        let outer_size: usize = shape[..actual_dim].iter().product();
        let inner_size: usize = shape[actual_dim + 1..].iter().product();
        let zero = <T as TensorElement>::zero();
        let mut result_data = vec![zero; data.len()];
        for outer in 0..outer_size {
            for inner in 0..inner_size {
                let base_idx = outer * dim_size * inner_size + inner;
                // Max over the axis, for numerical stability.
                let mut max_val = data[base_idx];
                for d in 1..dim_size {
                    let idx = base_idx + d * inner_size;
                    if data[idx] > max_val {
                        max_val = data[idx];
                    }
                }
                let mut exp_sum = zero;
                for d in 0..dim_size {
                    let idx = base_idx + d * inner_size;
                    exp_sum = exp_sum + (data[idx] - max_val).exp();
                }
                let log_sum_exp = exp_sum.ln();
                for d in 0..dim_size {
                    let idx = base_idx + d * inner_size;
                    result_data[idx] = data[idx] - max_val - log_sum_exp;
                }
            }
        }
        Self::from_data(result_data, shape.to_vec(), self.device)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;

    /// Shorthand for building an f32 CPU tensor in the tests below.
    fn cpu_tensor(values: Vec<f32>, dims: Vec<usize>) -> Tensor<f32> {
        Tensor::from_data(values, dims, DeviceType::Cpu).expect("tensor creation failed")
    }

    #[test]
    fn test_relu() {
        let t = cpu_tensor(vec![-2.0f32, -1.0, 0.0, 1.0, 2.0], vec![5]);
        let out = t.relu().expect("relu failed");
        let vals = out.data().expect("data retrieval failed");
        assert_eq!(vals.as_slice(), &[0.0, 0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_relu_inplace() {
        let mut t = cpu_tensor(vec![-2.0f32, -1.0, 0.0, 1.0, 2.0], vec![5]);
        t.relu_().expect("relu_ failed");
        let vals = t.data().expect("data retrieval failed");
        assert_eq!(vals.as_slice(), &[0.0, 0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_leaky_relu() {
        let t = cpu_tensor(vec![-2.0f32, -1.0, 0.0, 1.0, 2.0], vec![5]);
        let out = t.leaky_relu(0.1).expect("leaky_relu failed");
        let vals = out.data().expect("data retrieval failed");
        assert_eq!(vals.as_slice(), &[-0.2, -0.1, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_sigmoid() {
        let t = cpu_tensor(vec![0.0f32, 1.0, -1.0], vec![3]);
        let out = t.sigmoid().expect("sigmoid failed");
        let vals = out.data().expect("data retrieval failed");
        let expected = [0.5f32, 0.7310586, 0.26894143];
        for (got, want) in vals.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-6);
        }
    }

    #[test]
    fn test_tanh() {
        let t = cpu_tensor(vec![0.0f32, 1.0, -1.0], vec![3]);
        let out = t.tanh().expect("tanh failed");
        let vals = out.data().expect("data retrieval failed");
        let expected = [0.0f32, 0.7615942, -0.7615942];
        for (got, want) in vals.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-6);
        }
    }

    #[test]
    fn test_gelu() {
        let t = cpu_tensor(vec![0.0f32, 1.0, -1.0], vec![3]);
        let out = t.gelu().expect("gelu failed");
        let vals = out.data().expect("data retrieval failed");
        assert!((vals[0] - 0.0).abs() < 1e-5);
        assert!(vals[1] > 0.8);
        assert!(vals[2] < -0.1);
    }

    #[test]
    fn test_softmax() {
        let t = cpu_tensor(vec![1.0f32, 2.0, 3.0], vec![3]);
        let out = t.softmax(-1).expect("softmax failed");
        let vals = out.data().expect("data retrieval failed");
        let total: f32 = vals.iter().sum();
        assert!((total - 1.0).abs() < 1e-6);
        // Softmax is monotone: larger inputs get larger probabilities.
        assert!(vals[2] > vals[1]);
        assert!(vals[1] > vals[0]);
    }

    #[test]
    fn test_log_softmax() {
        let t = cpu_tensor(vec![1.0f32, 2.0, 3.0], vec![3]);
        let out = t.log_softmax(-1).expect("log_softmax failed");
        let vals = out.data().expect("data retrieval failed");
        // All log-probabilities are negative and ordered like the inputs.
        assert!(vals.iter().all(|&v| v < 0.0));
        assert!(vals[2] > vals[1]);
        assert!(vals[1] > vals[0]);
    }

    #[test]
    fn test_softmax_2d() {
        let t = cpu_tensor(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let out = t.softmax(-1).expect("softmax failed");
        let vals = out.data().expect("data retrieval failed");
        // Each row along the last dim must sum to 1.
        for row in vals.as_slice().chunks(3) {
            let row_sum: f32 = row.iter().sum();
            assert!((row_sum - 1.0).abs() < 1e-6);
        }
    }
}