use super::super::core::Tensor;
use crate::error::{RusTorchError, RusTorchResult};
use num_traits::Float;
use std::f64::consts::PI;
impl<T: Float + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive> Tensor<T> {
pub fn sinh(&self) -> Self {
let result_data: Vec<T> = self.data.iter().map(|&x| x.sinh()).collect();
Tensor::from_vec(result_data, self.shape().to_vec())
}
pub fn cosh(&self) -> Self {
let result_data: Vec<T> = self.data.iter().map(|&x| x.cosh()).collect();
Tensor::from_vec(result_data, self.shape().to_vec())
}
pub fn tanh(&self) -> Self {
let result_data: Vec<T> = self.data.iter().map(|&x| x.tanh()).collect();
Tensor::from_vec(result_data, self.shape().to_vec())
}
pub fn asinh(&self) -> Self {
let result_data: Vec<T> = self.data.iter().map(|&x| x.asinh()).collect();
Tensor::from_vec(result_data, self.shape().to_vec())
}
pub fn acosh(&self) -> Self {
let result_data: Vec<T> = self.data.iter().map(|&x| x.acosh()).collect();
Tensor::from_vec(result_data, self.shape().to_vec())
}
pub fn atanh(&self) -> Self {
let result_data: Vec<T> = self.data.iter().map(|&x| x.atanh()).collect();
Tensor::from_vec(result_data, self.shape().to_vec())
}
pub fn asin(&self) -> Self {
let result_data: Vec<T> = self.data.iter().map(|&x| x.asin()).collect();
Tensor::from_vec(result_data, self.shape().to_vec())
}
pub fn acos(&self) -> Self {
let result_data: Vec<T> = self.data.iter().map(|&x| x.acos()).collect();
Tensor::from_vec(result_data, self.shape().to_vec())
}
pub fn atan(&self) -> Self {
let result_data: Vec<T> = self.data.iter().map(|&x| x.atan()).collect();
Tensor::from_vec(result_data, self.shape().to_vec())
}
pub fn atan2(&self, other: &Tensor<T>) -> RusTorchResult<Self> {
if self.shape() != other.shape() && !self.can_broadcast_with(other) {
return Err(RusTorchError::shape_mismatch(self.shape(), other.shape()));
}
if self.shape() == other.shape() {
let result_data: Vec<T> = self
.data
.iter()
.zip(other.data.iter())
.map(|(&y, &x)| y.atan2(x))
.collect();
Ok(Tensor::from_vec(result_data, self.shape().to_vec()))
} else {
let (broadcasted_self, broadcasted_other) = self.broadcast_with(other)?;
let result_data: Vec<T> = broadcasted_self
.data
.iter()
.zip(broadcasted_other.data.iter())
.map(|(&y, &x)| y.atan2(x))
.collect();
Ok(Tensor::from_vec(
result_data,
broadcasted_self.shape().to_vec(),
))
}
}
pub fn floor(&self) -> Self {
let result_data: Vec<T> = self.data.iter().map(|&x| x.floor()).collect();
Tensor::from_vec(result_data, self.shape().to_vec())
}
pub fn ceil(&self) -> Self {
let result_data: Vec<T> = self.data.iter().map(|&x| x.ceil()).collect();
Tensor::from_vec(result_data, self.shape().to_vec())
}
pub fn round(&self) -> Self {
let result_data: Vec<T> = self.data.iter().map(|&x| x.round()).collect();
Tensor::from_vec(result_data, self.shape().to_vec())
}
pub fn trunc(&self) -> Self {
let result_data: Vec<T> = self.data.iter().map(|&x| x.trunc()).collect();
Tensor::from_vec(result_data, self.shape().to_vec())
}
pub fn pow_tensor(&self, exponent: &Tensor<T>) -> RusTorchResult<Self> {
if self.shape() != exponent.shape() && !self.can_broadcast_with(exponent) {
return Err(RusTorchError::shape_mismatch(
self.shape(),
exponent.shape(),
));
}
if self.shape() == exponent.shape() {
let result_data: Vec<T> = self
.data
.iter()
.zip(exponent.data.iter())
.map(|(&base, &exp)| base.powf(exp))
.collect();
Ok(Tensor::from_vec(result_data, self.shape().to_vec()))
} else {
let (broadcasted_self, broadcasted_exp) = self.broadcast_with(exponent)?;
let result_data: Vec<T> = broadcasted_self
.data
.iter()
.zip(broadcasted_exp.data.iter())
.map(|(&base, &exp)| base.powf(exp))
.collect();
Ok(Tensor::from_vec(
result_data,
broadcasted_self.shape().to_vec(),
))
}
}
pub fn square(&self) -> Self {
let result_data: Vec<T> = self.data.iter().map(|&x| x * x).collect();
Tensor::from_vec(result_data, self.shape().to_vec())
}
pub fn log10(&self) -> Self {
let result_data: Vec<T> = self.data.iter().map(|&x| x.log10()).collect();
Tensor::from_vec(result_data, self.shape().to_vec())
}
pub fn log2(&self) -> Self {
let result_data: Vec<T> = self.data.iter().map(|&x| x.log2()).collect();
Tensor::from_vec(result_data, self.shape().to_vec())
}
pub fn log(&self, base: T) -> Self {
let log_base = base.ln();
let result_data: Vec<T> = self.data.iter().map(|&x| x.ln() / log_base).collect();
Tensor::from_vec(result_data, self.shape().to_vec())
}
pub fn clamp(&self, min_val: T, max_val: T) -> Self {
let result_data: Vec<T> = self
.data
.iter()
.map(|&x| {
if x < min_val {
min_val
} else if x > max_val {
max_val
} else {
x
}
})
.collect();
Tensor::from_vec(result_data, self.shape().to_vec())
}
pub fn sign(&self) -> Self {
let result_data: Vec<T> = self
.data
.iter()
.map(|&x| {
if x > T::zero() {
T::one()
} else if x < T::zero() {
-T::one()
} else {
T::zero()
}
})
.collect();
Tensor::from_vec(result_data, self.shape().to_vec())
}
pub fn lerp(&self, other: &Tensor<T>, weight: T) -> RusTorchResult<Self> {
if self.shape() != other.shape() && !self.can_broadcast_with(other) {
return Err(RusTorchError::shape_mismatch(self.shape(), other.shape()));
}
if self.shape() == other.shape() {
let result_data: Vec<T> = self
.data
.iter()
.zip(other.data.iter())
.map(|(&a, &b)| a + weight * (b - a))
.collect();
Ok(Tensor::from_vec(result_data, self.shape().to_vec()))
} else {
let (broadcasted_self, broadcasted_other) = self.broadcast_with(other)?;
let result_data: Vec<T> = broadcasted_self
.data
.iter()
.zip(broadcasted_other.data.iter())
.map(|(&a, &b)| a + weight * (b - a))
.collect();
Ok(Tensor::from_vec(
result_data,
broadcasted_self.shape().to_vec(),
))
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` and `expected` have the same length and agree
    /// element-wise within the absolute tolerance `tol`.
    fn assert_close(actual: &[f64], expected: &[f64], tol: f64) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (a, e)) in actual.iter().zip(expected.iter()).enumerate() {
            assert!(
                (a - e).abs() <= tol,
                "element {}: got {}, expected {} (tol {})",
                i,
                a,
                e,
                tol
            );
        }
    }

    #[test]
    fn test_hyperbolic_functions() {
        let tensor = Tensor::from_vec(vec![0.0_f64, 1.0, -1.0], vec![3]);
        let sinh_result = tensor.sinh();
        let cosh_result = tensor.cosh();
        let tanh_result = tensor.tanh();
        let sinh_values = sinh_result.as_slice().unwrap();
        assert_eq!(sinh_values[0], 0.0);
        assert_eq!(cosh_result.as_slice().unwrap()[0], 1.0);
        assert_eq!(tanh_result.as_slice().unwrap()[0], 0.0);
        // sinh is odd, so sinh(1) + sinh(-1) must vanish.
        assert!((sinh_values[1] + sinh_values[2]).abs() < 1e-10);
    }

    #[test]
    fn test_inverse_hyperbolic_functions() {
        let zeros = Tensor::from_vec(vec![0.0_f64, 0.0], vec![2]);
        let ones = Tensor::from_vec(vec![1.0_f64, 1.0], vec![2]);
        assert_close(zeros.asinh().as_slice().unwrap(), &[0.0, 0.0], 1e-15);
        assert_close(ones.acosh().as_slice().unwrap(), &[0.0, 0.0], 1e-15);
        assert_close(zeros.atanh().as_slice().unwrap(), &[0.0, 0.0], 1e-15);
    }

    #[test]
    fn test_inverse_trigonometric_functions() {
        let tensor = Tensor::from_vec(vec![0.0_f64, 0.5, -0.5], vec![3]);
        let asin_result = tensor.asin();
        let acos_result = tensor.acos();
        let atan_result = tensor.atan();
        assert_eq!(asin_result.as_slice().unwrap()[0], 0.0);
        assert!((acos_result.as_slice().unwrap()[0] - PI / 2.0).abs() < 1e-10);
        assert_eq!(atan_result.as_slice().unwrap()[0], 0.0);
        // asin(±0.5) == ±π/6.
        assert_close(
            &asin_result.as_slice().unwrap()[1..],
            &[PI / 6.0, -PI / 6.0],
            1e-12,
        );
    }

    #[test]
    fn test_rounding_functions() {
        let tensor = Tensor::from_vec(vec![1.2, 2.7, -1.3, -2.8], vec![4]);
        let floor_result = tensor.floor();
        let ceil_result = tensor.ceil();
        let round_result = tensor.round();
        let trunc_result = tensor.trunc();
        assert_eq!(floor_result.as_slice().unwrap(), &[1.0, 2.0, -2.0, -3.0]);
        assert_eq!(ceil_result.as_slice().unwrap(), &[2.0, 3.0, -1.0, -2.0]);
        assert_eq!(round_result.as_slice().unwrap(), &[1.0, 3.0, -1.0, -3.0]);
        assert_eq!(trunc_result.as_slice().unwrap(), &[1.0, 2.0, -1.0, -2.0]);
    }

    #[test]
    fn test_power_functions() {
        let tensor = Tensor::from_vec(vec![1.0_f64, 4.0, 9.0, 16.0], vec![4]);
        let sqrt_result = tensor.sqrt();
        let square_result = tensor.square();
        let pow_result = tensor.pow(0.5);
        // IEEE sqrt is correctly rounded, so perfect squares give exact roots.
        assert_eq!(sqrt_result.as_slice().unwrap(), &[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(square_result.as_slice().unwrap(), &[1.0, 16.0, 81.0, 256.0]);
        // powf(x, 0.5) is not guaranteed to be bit-identical to sqrt(x) on every
        // libm, so compare with a tolerance instead of exact equality.
        assert_close(
            pow_result.as_slice().unwrap(),
            sqrt_result.as_slice().unwrap(),
            1e-12,
        );
    }

    #[test]
    fn test_pow_tensor() {
        let base = Tensor::from_vec(vec![2.0_f64, 3.0, 4.0], vec![3]);
        let exponent = Tensor::from_vec(vec![3.0_f64, 2.0, 0.5], vec![3]);
        let result = base.pow_tensor(&exponent).unwrap();
        assert_close(result.as_slice().unwrap(), &[8.0, 9.0, 2.0], 1e-12);
    }

    #[test]
    fn test_logarithms() {
        let powers_of_two = Tensor::from_vec(vec![1.0_f64, 8.0, 1024.0], vec![3]);
        assert_close(
            powers_of_two.log2().as_slice().unwrap(),
            &[0.0, 3.0, 10.0],
            1e-12,
        );
        assert_close(
            powers_of_two.log(2.0).as_slice().unwrap(),
            &[0.0, 3.0, 10.0],
            1e-12,
        );
        let powers_of_ten = Tensor::from_vec(vec![1.0_f64, 10.0, 1000.0], vec![3]);
        assert_close(
            powers_of_ten.log10().as_slice().unwrap(),
            &[0.0, 1.0, 3.0],
            1e-12,
        );
    }

    #[test]
    fn test_comparison_functions() {
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]);
        let b = Tensor::from_vec(vec![2.0, 1.0, 3.0], vec![3]);
        let max_result = a.maximum(&b).unwrap();
        let min_result = a.minimum(&b).unwrap();
        assert_eq!(max_result.as_slice().unwrap(), &[2.0, 2.0, 3.0]);
        assert_eq!(min_result.as_slice().unwrap(), &[1.0, 1.0, 3.0]);
    }

    #[test]
    fn test_clamp_and_sign() {
        let tensor = Tensor::from_vec(vec![-2.0, -0.5, 0.0, 0.5, 2.0], vec![5]);
        let clamped = tensor.clamp(-1.0, 1.0);
        let sign_result = tensor.sign();
        assert_eq!(clamped.as_slice().unwrap(), &[-1.0, -0.5, 0.0, 0.5, 1.0]);
        assert_eq!(
            sign_result.as_slice().unwrap(),
            &[-1.0, -1.0, 0.0, 1.0, 1.0]
        );
    }

    #[test]
    fn test_atan2() {
        let y = Tensor::from_vec(vec![1.0_f64, 1.0, -1.0, -1.0], vec![4]);
        let x = Tensor::from_vec(vec![1.0_f64, -1.0, 1.0, -1.0], vec![4]);
        let atan2_result = y.atan2(&x).unwrap();
        // One point in each quadrant.
        let expected = [PI / 4.0, 3.0 * PI / 4.0, -PI / 4.0, -3.0 * PI / 4.0];
        assert_close(atan2_result.as_slice().unwrap(), &expected, 1e-10);
    }

    #[test]
    fn test_lerp() {
        let a = Tensor::from_vec(vec![0.0, 2.0, 4.0], vec![3]);
        let b = Tensor::from_vec(vec![10.0, 20.0, 40.0], vec![3]);
        let lerp_result = a.lerp(&b, 0.5).unwrap();
        assert_eq!(lerp_result.as_slice().unwrap(), &[5.0, 11.0, 22.0]);
        let lerp_result_0 = a.lerp(&b, 0.0).unwrap();
        assert_eq!(lerp_result_0.as_slice().unwrap(), a.as_slice().unwrap());
        let lerp_result_1 = a.lerp(&b, 1.0).unwrap();
        assert_eq!(lerp_result_1.as_slice().unwrap(), b.as_slice().unwrap());
    }
}