use zyx::{Tensor, ZyxError};
/// Reducing a two-element 1-D tensor with `sum_all` yields the sum of its elements.
#[test]
fn sum_1() -> Result<(), ZyxError> {
    let vector = Tensor::from([2, 4]);
    let total = vector.sum_all();
    assert_eq!(total, 6);
    Ok(())
}
/// Reduces a 3×3 tensor along each axis separately and then over all elements.
///
/// Negative axes count from the end: on this 3×3 input `-1` produces per-row
/// totals and `-2` produces per-column totals.
#[test]
fn sum_2() -> Result<(), ZyxError> {
    let m = Tensor::from([[4, 1, 3], [5, 2, 3], [6, 5, 7]]);
    let (row_totals, column_totals, grand_total) = (m.sum([-1])?, m.sum([-2])?, m.sum_all());
    assert_eq!(row_totals, [8, 10, 18]);
    assert_eq!(column_totals, [15, 8, 13]);
    assert_eq!(grand_total, [36]);
    Ok(())
}
/// Sums a 2×3 tensor over axis 0, over axis 1, and over all elements, using
/// non-negative axis indices.
#[test]
fn sum_3() -> Result<(), ZyxError> {
    let m = Tensor::from([[2, 4, 3], [1, 5, 1]]);
    let down_columns = m.sum([0])?;
    assert_eq!(down_columns, [3, 9, 4]);
    let across_rows = m.sum([1])?;
    assert_eq!(across_rows, [9, 7]);
    let everything = m.sum_all();
    assert_eq!(everything, 16);
    Ok(())
}
/// Chains `relu` into a last-axis `sum`.
///
/// Every input element is positive, so the expected row totals are the same as
/// those of a plain last-axis sum of this matrix.
#[test]
fn sum_4() -> Result<(), ZyxError> {
    let m = Tensor::from([[4, 1, 3], [5, 2, 3], [6, 5, 7]]);
    let rectified = m.relu();
    let row_totals = rectified.sum([-1])?;
    assert_eq!(row_totals, [8, 10, 18]);
    Ok(())
}
/// `sum_all` over a 2×3 tensor collapses every element into a single total.
///
/// Uses `assert_eq!` rather than `debug_assert_eq!`: the latter is compiled out
/// when debug assertions are disabled (e.g. `cargo test --release`), which would
/// leave this test checking nothing.
#[test]
fn sum_5() -> Result<(), ZyxError> {
    let x = Tensor::from([[2, 3, 1], [2, 4, 1]]);
    let total = x.sum_all();
    assert_eq!(total, [13i32]);
    Ok(())
}
/// Applies two successive axis-0 sums to a 3×3 tensor. The first collapses the
/// rows into per-column totals; the second collapses those into the grand total.
///
/// Uses `assert_eq!` rather than `debug_assert_eq!`: the latter is compiled out
/// when debug assertions are disabled (e.g. `cargo test --release`), which would
/// leave this test checking nothing.
#[test]
fn sum_6() -> Result<(), ZyxError> {
    let x = Tensor::from([[4i32, 1, 3], [5, 2, 3], [6, 5, 7]]);
    let column_totals = x.sum([0])?;
    let total = column_totals.sum([0])?;
    assert_eq!(total, [36i32]);
    Ok(())
}
/// Takes the maximum of a 3×3 tensor along the last axis, along the
/// second-to-last axis, and over all elements.
#[test]
fn max_1() -> Result<(), ZyxError> {
    let m = Tensor::from([[4, 1, 3], [5, 2, 3], [6, 5, 7]]);
    let (row_max, column_max, overall_max) = (m.max([-1])?, m.max([-2])?, m.max_all());
    assert_eq!(row_max, [4, 5, 7]);
    assert_eq!(column_max, [6, 5, 7]);
    assert_eq!(overall_max, [7]);
    Ok(())
}
/// Row-wise sum of an 8×10 tensor holding `1..=80` in row-major order.
///
/// Row `r` (0-based) holds `10r + 1 ..= 10r + 10`, so its total is `100r + 55`.
#[test]
fn sum_large_2d() -> Result<(), ZyxError> {
    let grid = Tensor::from([
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        [11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
        [21, 22, 23, 24, 25, 26, 27, 28, 29, 30],
        [31, 32, 33, 34, 35, 36, 37, 38, 39, 40],
        [41, 42, 43, 44, 45, 46, 47, 48, 49, 50],
        [51, 52, 53, 54, 55, 56, 57, 58, 59, 60],
        [61, 62, 63, 64, 65, 66, 67, 68, 69, 70],
        [71, 72, 73, 74, 75, 76, 77, 78, 79, 80],
    ]);
    let row_totals = grid.sum([-1])?;
    let expected = [55, 155, 255, 355, 455, 555, 655, 755];
    assert_eq!(row_totals, expected);
    Ok(())
}
/// Maximum over the last axis of a 3×3×3 tensor holding `1..=27` in row-major order.
///
/// Each innermost triple is increasing, so its maximum is its last element.
#[test]
fn max_large_3d() -> Result<(), ZyxError> {
    let cube = Tensor::from([
        [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
        [[10, 11, 12], [13, 14, 15], [16, 17, 18]],
        [[19, 20, 21], [22, 23, 24], [25, 26, 27]],
    ]);
    let innermost_max = cube.max([-1])?;
    let expected = [[3, 6, 9], [12, 15, 18], [21, 24, 27]];
    assert_eq!(innermost_max, expected);
    Ok(())
}
/// Sums 512 ones reshaped to `[8, 64]` along the last axis; each of the 8
/// outputs should equal the row length, 64.
#[test]
fn sum_warp_reduce_1() -> Result<(), ZyxError> {
    let flat = Tensor::from(vec![1i32; 512]);
    let matrix = flat.reshape([8, 64])?;
    let expected = vec![64i32; 8];
    assert_eq!(matrix.sum([-1])?, expected);
    Ok(())
}
/// Sums 2048 ones reshaped to `[8, 256]` along the last axis; each of the 8
/// outputs should equal the row length, 256.
#[test]
fn sum_warp_reduce_large() -> Result<(), ZyxError> {
    let flat = Tensor::from(vec![1i32; 2048]);
    let matrix = flat.reshape([8, 256])?;
    let expected = vec![256i32; 8];
    assert_eq!(matrix.sum([-1])?, expected);
    Ok(())
}
/// Sums 32768 ones reshaped to `[8, 4096]` along the last axis; each of the 8
/// outputs should equal the row length, 4096.
#[test]
fn sum_reduce_32k() -> Result<(), ZyxError> {
    let flat = Tensor::from(vec![1i32; 32768]);
    let matrix = flat.reshape([8, 4096])?;
    let expected = vec![4096i32; 8];
    assert_eq!(matrix.sum([-1])?, expected);
    Ok(())
}
/// Sums the last axis of a `[4, 1024, 512]` tensor and checks all 4 × 1024 row
/// totals against a reference computed on the host.
///
/// Flat element `k` is `k % 10`. Because 512 is not a multiple of 10, adjacent
/// rows start at different phases of the 0..=9 cycle and have different totals,
/// so a reduction that reads the wrong row is detected.
#[test]
fn sum_reduce_last_dim_3d() -> Result<(), ZyxError> {
    let data: Vec<i32> = (0..(4 * 1024 * 512)).map(|i| (i % 10) as i32).collect();
    let x = Tensor::from(data);
    let x = x.reshape([4, 1024, 512])?;
    let x0 = x.sum([-1])?;

    // Host reference: row (i, j) covers flat indices
    // `(i * 1024 + j) * 512 .. (i * 1024 + j) * 512 + 512`.
    let mut expected = [[0i32; 1024]; 4];
    for (i, plane) in expected.iter_mut().enumerate() {
        for (j, total) in plane.iter_mut().enumerate() {
            let start = (i * 1024 + j) * 512;
            *total = (start..start + 512).map(|k| (k % 10) as i32).sum();
        }
    }

    // `assert!` rather than `assert_eq!`: the latter would print all 4096
    // expected values on failure.
    assert!(
        x0 == expected,
        "sum over the last axis of a [4, 1024, 512] tensor differs from the host reference"
    );
    Ok(())
}