use super::super::core::Tensor;
use num_traits::Float;
impl<T: Float + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive> Tensor<T> {
    /// Applies `f` to every element and returns a new tensor with the same shape as `self`.
    ///
    /// Shared implementation for the element-wise methods in this block, so each public
    /// method only states the scalar operation it performs. `self` is not modified.
    ///
    /// NOTE(review): assumes `self.data.iter()` yields elements in the same order that
    /// `Tensor::from_vec` lays them out (logical row-major order) — confirm against
    /// `core::Tensor`.
    fn apply_elementwise<F>(&self, f: F) -> Self
    where
        F: Fn(T) -> T,
    {
        let result_data: Vec<T> = self.data.iter().map(|&x| f(x)).collect();
        Tensor::from_vec(result_data, self.shape().to_vec())
    }

    /// Returns a new tensor holding `e^x` for every element `x`.
    ///
    /// The result has the same shape as `self`.
    pub fn exp(&self) -> Self {
        self.apply_elementwise(|x| x.exp())
    }

    /// Returns a new tensor holding the natural logarithm of every element.
    ///
    /// No domain check is performed. For `f32`/`f64`, zero maps to `-inf` and negative
    /// values map to `NaN`, following IEEE 754 semantics.
    /// The result has the same shape as `self`.
    pub fn ln(&self) -> Self {
        self.apply_elementwise(|x| x.ln())
    }

    /// Returns a new tensor holding the sine of every element (elements are radians).
    ///
    /// The result has the same shape as `self`.
    pub fn sin(&self) -> Self {
        self.apply_elementwise(|x| x.sin())
    }

    /// Returns a new tensor holding the cosine of every element (elements are radians).
    ///
    /// The result has the same shape as `self`.
    pub fn cos(&self) -> Self {
        self.apply_elementwise(|x| x.cos())
    }

    /// Returns a new tensor holding the tangent of every element (elements are radians).
    ///
    /// The result has the same shape as `self`.
    pub fn tan(&self) -> Self {
        self.apply_elementwise(|x| x.tan())
    }

    /// Returns a new tensor holding the absolute value of every element.
    ///
    /// The result has the same shape as `self`.
    pub fn abs(&self) -> Self {
        self.apply_elementwise(|x| x.abs())
    }

    /// Returns a new tensor with every element raised to the floating-point power `exponent`.
    ///
    /// This uses `Float::powf`. For `f32`/`f64`, a negative base with a non-integer
    /// exponent yields `NaN`. The result has the same shape as `self`.
    pub fn pow(&self, exponent: T) -> Self {
        self.apply_elementwise(|x| x.powf(exponent))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` and `expected` have the same length and agree element-wise.
    ///
    /// The allowed error is `tolerance`, scaled up by `|expected|` when that magnitude
    /// exceeds 1. Small values are therefore compared absolutely and large ones relatively.
    fn assert_close(actual: &[f32], expected: &[f32], tolerance: f32) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (&a, &e)) in actual.iter().zip(expected).enumerate() {
            let allowed = tolerance * e.abs().max(1.0);
            assert!(
                (a - e).abs() <= allowed,
                "element {i}: got {a}, expected {e}"
            );
        }
    }

    #[test]
    fn test_mathematical_functions() {
        let tensor = Tensor::from_vec(vec![1.0_f32, 4.0, 9.0, 16.0], vec![2, 2]);

        let sqrt_result = tensor.sqrt();
        let expected: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        assert_eq!(sqrt_result.as_slice().unwrap(), expected);

        // Check values as well as shape, so a wrong or reordered result is caught.
        let exp_result = tensor.exp();
        assert_eq!(exp_result.shape(), tensor.shape());
        assert_close(
            exp_result.as_slice().unwrap(),
            &[std::f32::consts::E, 54.598_15, 8_103.084, 8_886_110.0],
            1e-5,
        );

        // ln(1) = 0, ln(4) = 2 ln 2, ln(9) = 2 ln 3, ln(16) = 4 ln 2.
        let ln_result = tensor.ln();
        assert_eq!(ln_result.shape(), tensor.shape());
        let ln2 = std::f32::consts::LN_2;
        assert_close(
            ln_result.as_slice().unwrap(),
            &[0.0, 2.0 * ln2, 2.197_224_6, 4.0 * ln2],
            1e-5,
        );

        // exp must invert ln element-wise.
        assert_close(
            ln_result.exp().as_slice().unwrap(),
            &[1.0, 4.0, 9.0, 16.0],
            1e-5,
        );
    }

    #[test]
    fn test_trigonometric_functions() {
        let tensor = Tensor::from_vec(vec![0.0_f32, std::f32::consts::PI / 2.0], vec![2]);

        let sin_result = tensor.sin();
        assert_eq!(sin_result.shape(), tensor.shape());
        assert_close(sin_result.as_slice().unwrap(), &[0.0, 1.0], 1e-6);

        let cos_result = tensor.cos();
        assert_eq!(cos_result.shape(), tensor.shape());
        assert_close(cos_result.as_slice().unwrap(), &[1.0, 0.0], 1e-6);

        // tan(pi/2) is unbounded, so tan gets its own well-conditioned inputs.
        let tan_input = Tensor::from_vec(vec![0.0_f32, std::f32::consts::FRAC_PI_4], vec![2]);
        let tan_result = tan_input.tan();
        assert_eq!(tan_result.shape(), tan_input.shape());
        assert_close(tan_result.as_slice().unwrap(), &[0.0, 1.0], 1e-6);
    }

    #[test]
    fn test_abs() {
        let tensor = Tensor::from_vec(vec![-1.5_f32, 0.0, 2.0, -0.25], vec![2, 2]);
        let abs_result = tensor.abs();
        assert_eq!(abs_result.shape(), tensor.shape());
        let expected: Vec<f32> = vec![1.5, 0.0, 2.0, 0.25];
        assert_eq!(abs_result.as_slice().unwrap(), expected);
    }

    #[test]
    fn test_pow() {
        let tensor = Tensor::from_vec(vec![1.0_f32, 2.0, 3.0, 4.0], vec![4]);

        let squared = tensor.pow(2.0);
        assert_eq!(squared.shape(), tensor.shape());
        assert_close(squared.as_slice().unwrap(), &[1.0, 4.0, 9.0, 16.0], 1e-6);

        // A power of 0.5 must agree with sqrt.
        let root = tensor.pow(0.5);
        assert_close(
            root.as_slice().unwrap(),
            &[1.0, std::f32::consts::SQRT_2, 1.732_050_8, 2.0],
            1e-6,
        );

        // Anything to the power zero is one.
        assert_close(tensor.pow(0.0).as_slice().unwrap(), &[1.0, 1.0, 1.0, 1.0], 0.0);
    }
}