use approx::assert_relative_eq;
use torsh_tensor::creation::*;
use torsh_tensor::tensor;
/// Construction through the `tensor!` macro and `tensor_2d`: checks shape,
/// element count, and the stored values.
#[test]
fn test_tensor_creation() {
    // A single-element `tensor!` is expected to produce a rank-0 tensor.
    let scalar = tensor![5.0f32].unwrap();
    assert_eq!(scalar.shape().dims(), &[] as &[usize]);
    assert_eq!(scalar.numel(), 1);
    assert_eq!(scalar.item().unwrap(), 5.0f32);

    let vec1d = tensor![1.0f32, 2.0f32, 3.0f32].unwrap();
    assert_eq!(vec1d.shape().dims(), &[3]);
    assert_eq!(vec1d.numel(), 3);
    // Shape alone does not prove the data was stored; check the contents too.
    assert_eq!(vec1d.to_vec().unwrap(), vec![1.0f32, 2.0, 3.0]);

    let mat2d = tensor_2d(&[&[1.0f32, 2.0f32], &[3.0f32, 4.0f32]]).unwrap();
    assert_eq!(mat2d.shape().dims(), &[2, 2]);
    assert_eq!(mat2d.numel(), 4);
    // `to_vec` is expected to flatten in row-major order.
    assert_eq!(mat2d.to_vec().unwrap(), vec![1.0f32, 2.0, 3.0, 4.0]);
}
/// `zeros`, `ones`, and `eye` produce the requested shapes filled with the
/// expected constants.
#[test]
fn test_zeros_ones() {
    let z = zeros::<f32>(&[3, 4]).unwrap();
    assert_eq!(z.shape().dims(), &[3, 4]);
    assert_eq!(z.numel(), 12);
    assert_eq!(z.to_vec().unwrap(), vec![0.0f32; 12]);

    let o = ones::<f32>(&[2, 2]).unwrap();
    assert_eq!(o.shape().dims(), &[2, 2]);
    assert_eq!(o.numel(), 4);
    assert_eq!(o.to_vec().unwrap(), vec![1.0f32; 4]);

    let e = eye::<f32>(3).unwrap();
    assert_eq!(e.shape().dims(), &[3, 3]);
    // Row-major 3x3 identity: ones on the diagonal, zeros elsewhere.
    let expected = vec![1.0f32, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
    assert_eq!(e.to_vec().unwrap(), expected);
}
/// Element-wise `add`, `sub`, `mul_op`, and `div` on equally shaped 1-D tensors.
#[test]
fn test_basic_operations() {
    let a = tensor![1.0, 2.0, 3.0].unwrap();
    let b = tensor![4.0, 5.0, 6.0].unwrap();

    let c = a.add(&b).unwrap();
    assert_eq!(c.to_vec().unwrap(), vec![5.0, 7.0, 9.0]);

    let d = b.sub(&a).unwrap();
    assert_eq!(d.to_vec().unwrap(), vec![3.0, 3.0, 3.0]);

    let e = a.mul_op(&b).unwrap();
    assert_eq!(e.to_vec().unwrap(), vec![4.0, 10.0, 18.0]);

    // Convert once instead of once per element. Pin the length so extra
    // trailing elements are not silently ignored by the indexed checks.
    let f = b.div(&a).unwrap().to_vec().unwrap();
    assert_eq!(f.len(), 3);
    assert_relative_eq!(f[0], 4.0, epsilon = 1e-6);
    assert_relative_eq!(f[1], 2.5, epsilon = 1e-6);
    assert_relative_eq!(f[2], 2.0, epsilon = 1e-6);
}
/// 2x2 matrix product checked against a hand-computed result.
#[test]
fn test_matrix_multiplication() {
    let a = tensor_2d(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap();
    let b = tensor_2d(&[&[5.0, 6.0], &[7.0, 8.0]]).unwrap();
    let c = a.matmul(&b).unwrap();
    // (2x2) @ (2x2) must stay 2-D; without this a flattened result would pass.
    assert_eq!(c.shape().dims(), &[2, 2]);
    // [[1*5 + 2*7, 1*6 + 2*8], [3*5 + 4*7, 3*6 + 4*8]] in row-major order.
    let expected = vec![19.0, 22.0, 43.0, 50.0];
    assert_eq!(c.to_vec().unwrap(), expected);
}
/// Transposing a 2x3 matrix yields its 3x2 counterpart.
#[test]
fn test_transpose() {
    let rows = tensor_2d(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]]).unwrap();
    let transposed = rows.t().unwrap();

    assert_eq!(transposed.shape().dims(), &[3, 2]);
    // Column j of the input becomes row j of the output.
    assert_eq!(
        transposed.to_vec().unwrap(),
        vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]
    );
}
/// Whole-tensor reductions (sum, mean, max, min) over [[1, 2], [3, 4]].
#[test]
fn test_reductions() {
    let m = tensor_2d(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap();

    assert_eq!(m.sum().unwrap().item().unwrap(), 10.0);
    // NOTE(review): `(None, false)` presumably means "all dims, no keepdim";
    // confirm against the `mean`/`max` signatures.
    assert_eq!(m.mean(None, false).unwrap().item().unwrap(), 2.5);
    assert_eq!(m.max(None, false).unwrap().item().unwrap(), 4.0);
    assert_eq!(m.min().unwrap().item().unwrap(), 1.0);
}
/// ReLU, sigmoid, and tanh evaluated at x = -1, 0, 1, 2.
#[test]
fn test_activations() {
    let a = tensor_2d(&[&[-1.0, 0.0, 1.0, 2.0]]).unwrap();

    let relu = a.relu().unwrap();
    assert_eq!(relu.to_vec().unwrap(), vec![0.0, 0.0, 1.0, 2.0]);

    // Every element is checked, not just x = 0, so an implementation that is
    // only right at the origin (e.g. a constant 0.5 / 0.0) is caught.
    // Reference: sigmoid(x) = 1 / (1 + e^-x).
    let sigmoid = a.sigmoid().unwrap().to_vec().unwrap();
    let expected_sigmoid = [0.268_941_421, 0.5, 0.731_058_579, 0.880_797_078];
    assert_eq!(sigmoid.len(), expected_sigmoid.len());
    for (&got, &want) in sigmoid.iter().zip(expected_sigmoid.iter()) {
        assert_relative_eq!(got, want, epsilon = 1e-6);
    }

    let tanh = a.tanh().unwrap().to_vec().unwrap();
    let expected_tanh = [-0.761_594_156, 0.0, 0.761_594_156, 0.964_027_580];
    assert_eq!(tanh.len(), expected_tanh.len());
    for (&got, &want) in tanh.iter().zip(expected_tanh.iter()) {
        assert_relative_eq!(got, want, epsilon = 1e-6);
    }
}
/// Broadcasting a (3, 1) tensor against a (1, 4) tensor yields (3, 4), with
/// row i of the left operand paired with column j of the right operand.
#[test]
fn test_broadcasting() {
    let a = ones::<f32>(&[3, 1]).unwrap();
    let b = ones::<f32>(&[1, 4]).unwrap();
    let c = a.add(&b).unwrap();
    assert_eq!(c.shape().dims(), &[3, 4]);
    assert_eq!(c.to_vec().unwrap(), vec![2.0; 12]);

    // With all-ones inputs every element mapping gives 2.0, so use distinct
    // values to verify which elements are actually combined.
    let col = tensor_2d(&[&[0.0f32], &[10.0], &[20.0]]).unwrap();
    let row = tensor_2d(&[&[1.0f32, 2.0, 3.0, 4.0]]).unwrap();
    let sum = col.add(&row).unwrap();
    assert_eq!(sum.shape().dims(), &[3, 4]);
    let expected = vec![
        1.0f32, 2.0, 3.0, 4.0, //
        11.0, 12.0, 13.0, 14.0, //
        21.0, 22.0, 23.0, 24.0,
    ];
    assert_eq!(sum.to_vec().unwrap(), expected);
}
/// Backpropagating through y = x^2 at x = 2 gives dy/dx = 4.
#[test]
fn test_gradient_tracking() {
    let x = tensor![2.0].unwrap().requires_grad_(true);
    assert!(x.requires_grad());

    let squared = x.pow(2.0).unwrap();
    // An op applied to a grad-tracking input must itself track gradients.
    assert!(squared.requires_grad());

    squared.backward().unwrap();
    let dx = x.grad().unwrap();
    assert_eq!(dx.item().unwrap(), 4.0);
}
/// `rand` samples lie in [0, 1); `randn` yields finite samples; both honour
/// the requested shape.
#[test]
fn test_random_tensors() {
    let r = rand::<f32>(&[3, 3]).unwrap();
    assert_eq!(r.shape().dims(), &[3, 3]);
    let samples = r.to_vec().unwrap();
    // Without the length check an empty Vec would pass the range loop vacuously.
    assert_eq!(samples.len(), 9);
    for val in samples {
        assert!((0.0..1.0).contains(&val), "rand sample {} outside [0, 1)", val);
    }

    let n = randn::<f32>(&[2, 2]).unwrap();
    assert_eq!(n.shape().dims(), &[2, 2]);
    let normals = n.to_vec().unwrap();
    assert_eq!(normals.len(), 4);
    assert!(normals.iter().all(|v| v.is_finite()));
}
/// `arange` excludes its end point; `linspace` includes both end points.
#[test]
fn test_arange_linspace() {
    let a = arange(0.0f32, 5.0, 1.0).unwrap();
    assert_eq!(a.to_vec().unwrap(), vec![0.0, 1.0, 2.0, 3.0, 4.0]);

    let l = linspace(0.0f32, 1.0, 5).unwrap();
    let values = l.to_vec().unwrap();
    let expected = [0.0, 0.25, 0.5, 0.75, 1.0];
    // `zip` stops at the shorter side, so pin the length first; otherwise a
    // short (or empty) result would pass unnoticed.
    assert_eq!(values.len(), expected.len());
    for (actual, expected) in values.iter().zip(expected.iter()) {
        assert_relative_eq!(actual, expected, epsilon = 1e-6);
    }
}
/// Scalar addition and multiplication are applied to every element.
#[test]
fn test_scalar_operations() {
    let base = tensor![1.0, 2.0, 3.0].unwrap();

    let shifted = base.add_scalar(5.0).unwrap().to_vec().unwrap();
    assert_eq!(shifted, vec![6.0, 7.0, 8.0]);

    let doubled = base.mul_scalar(2.0).unwrap().to_vec().unwrap();
    assert_eq!(doubled, vec![2.0, 4.0, 6.0]);
}
/// Incompatible shapes are reported as `Err` by `add` and `matmul`.
#[test]
fn test_shape_errors() {
    let row_1x3 = tensor_2d(&[&[1.0, 2.0, 3.0]]).unwrap();
    let square_2x2 = tensor_2d(&[&[4.0, 5.0], &[6.0, 7.0]]).unwrap();
    // (1, 3) and (2, 2) cannot be broadcast together.
    assert!(row_1x3.add(&square_2x2).is_err());

    let other_2x2 = tensor_2d(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap();
    // Inner dimensions 3 and 2 do not match for (1, 3) @ (2, 2).
    assert!(row_1x3.matmul(&other_2x2).is_err());
}