use zyx::{DType, Scalar, Tensor, ZyxError};
/// Reference single-precision matrix multiply used to cross-check zyx results.
///
/// Computes `C = A · B` where `a` is a dense row-major `m × k` matrix and `b` is a dense
/// row-major `k × n` matrix, returning a freshly allocated dense row-major `m × n` result.
///
/// # Panics
///
/// Panics if `a.len() != m * k` or `b.len() != k * n`. Without this check the raw-pointer
/// GEMM call below could read past the end of the input slices.
#[allow(unused)]
fn matmul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    assert_eq!(a.len(), m * k, "lhs must contain m * k elements");
    assert_eq!(b.len(), k * n, "rhs must contain k * n elements");
    let mut c = vec![0f32; m * n];
    // SAFETY: `a` holds exactly m*k elements and `b` exactly k*n (asserted above), and `c`
    // was allocated with m*n elements. The strides passed (row stride = column count,
    // column stride = 1) describe dense row-major matrices of exactly those shapes, so
    // every element sgemm touches is in bounds. `c` is the only mutable pointer and does
    // not alias `a` or `b`.
    unsafe {
        matrixmultiply::sgemm(
            m,
            k,
            n,
            1.0,
            a.as_ptr(),
            k as isize,
            1,
            b.as_ptr(),
            n as isize,
            1,
            0.0,
            c.as_mut_ptr(),
            n as isize,
            1,
        )
    };
    c
}
/// Max-pools a 4×4 ramp down to 2×2; each output cell is the maximum of one 2×2 quadrant.
#[test]
fn test_max_pool() -> Result<(), ZyxError> {
    let grid = Tensor::from([
        [1.0f32, 2.0, 3.0, 4.0],
        [5.0, 6.0, 7.0, 8.0],
        [9.0, 10.0, 11.0, 12.0],
        [13.0, 14.0, 15.0, 16.0],
    ]);
    let pooled = grid.max_pool([2, 2], [2, 2], [1, 1], [(0, 0), (0, 0)], false, false)?;
    assert_eq!(pooled.shape(), [2, 2]);
    assert_eq!(pooled, [[6.0f32, 8.0], [14.0, 16.0]]);
    Ok(())
}
/// A 2×2 integer tensor built from nested arrays reads back unchanged.
#[test]
fn memory1() {
    let stored = Tensor::from([[2, 3], [4, 5]]);
    let expected = [[2, 3], [4, 5]];
    assert_eq!(stored, expected);
}
/// A 2×3 integer tensor built from nested arrays reads back unchanged.
#[test]
fn memory2() -> Result<(), ZyxError> {
    let stored = Tensor::from([[2, 4, 3], [1, 5, 1]]);
    let expected = [[2, 4, 3], [1, 5, 1]];
    assert_eq!(stored, expected);
    Ok(())
}
/// Builds `sqrt(x) + exp2(y)` over two 2×3 tensors cast to f32 and checks that
/// realization succeeds (no values are asserted).
#[test]
fn complex_binary() -> Result<(), ZyxError> {
    let lhs = Tensor::from([[2, 4, 3], [1, 5, 1]]).cast(DType::F32);
    let rhs = Tensor::from([[2, 4, 3], [1, 5, 7]]).cast(DType::F32);
    let root = lhs.sqrt();
    let power = rhs.exp2();
    let sum = root + power;
    Tensor::realize([&sum])?;
    Ok(())
}
/// `Tensor::tri(3, 5, 2, I32)`: per the expected matrix, row `r` holds ones at columns
/// `>= r + 2` and zeros elsewhere.
#[test]
fn tri1() -> Result<(), ZyxError> {
    let upper = Tensor::tri(3, 5, 2, DType::I32);
    let expected = [[0i32, 0, 1, 1, 1], [0, 0, 0, 1, 1], [0, 0, 0, 0, 1]];
    assert_eq!(upper, expected);
    Ok(())
}
/// `exp2(x) + x` on a 2×3 f32 tensor, where `x` feeds both operands.
#[test]
fn fuse_1() -> Result<(), ZyxError> {
    let base = Tensor::from([[2f32, 4., 3.], [1., 5., 1.]]);
    let powers = base.exp2();
    let total = powers + base;
    assert_eq!(total, [[6f32, 20., 11.], [3., 37., 3.]]);
    Ok(())
}
/// Adds a 2×3 tensor to its own `[2, 2, 3]` expansion; every element ends up doubled.
#[test]
fn fuse_2() -> Result<(), ZyxError> {
    let base = Tensor::from([[2f32, 4., 3.], [1., 5., 1.]]);
    let stacked = base.expand([2, 2, 3])?;
    let doubled = stacked + base;
    assert_eq!(
        doubled,
        [[[4f32, 8., 6.], [2., 10., 2.]], [[4., 8., 6.], [2., 10., 2.]]]
    );
    Ok(())
}
/// Column sums of a 2×3 tensor, expanded back to 2×3 and added to the original.
#[test]
fn fuse_3() -> Result<(), ZyxError> {
    let base = Tensor::from([[2f32, 4., 3.], [1., 5., 1.]]);
    let column_sums = base.sum([0])?;
    let broadcast = column_sums.expand([2, 3])?;
    let total = broadcast + base;
    assert_eq!(total, [[5f32, 13., 7.], [4., 14., 5.]]);
    Ok(())
}
/// Realizes two outputs that share the intermediate `exp2(y)`: `x + exp2(y)` and
/// `exp2(exp2(y))`. Only checks that realization succeeds.
#[test]
fn fuse_4() -> Result<(), ZyxError> {
    let lhs = Tensor::from([[2f32, 4., 3.], [1., 5., 1.]]);
    let shared = Tensor::from([[2f32, 4., 3.], [1., 5., 3.]]).exp2();
    let sum = lhs + &shared;
    let nested = shared.exp2();
    Tensor::realize([&sum, &nested])?;
    Ok(())
}
/// Two branches off one transposed i32 tensor: `log2` (transposed back) and `exp2`
/// (reshaped to 2×3). Both are realized together; no values are asserted.
#[test]
fn fuse_5() -> Result<(), ZyxError> {
    let mut base = Tensor::from([[2i32, 4, 3], [1, 5, 1]]).t();
    let mut logs = base.log2();
    base = base.exp2();
    base = base.reshape([2, 3])?;
    logs = logs.t();
    Tensor::realize([&base, &logs])?;
    Ok(())
}
/// Row sums of an i32 tensor feed both `log2` and `exp2`; both are realized together and
/// compared with their f32 expectations.
#[test]
fn fuse_6() -> Result<(), ZyxError> {
    let row_sums = Tensor::from([[2i32, 4, 3], [1, 5, 1]]).sum([1])?;
    let logs = row_sums.log2();
    let powers = row_sums.exp2();
    Tensor::realize([&powers, &logs])?;
    assert_eq!(powers, [512f32, 128.]);
    assert_eq!(logs, [3.16993f32, 2.807355]);
    Ok(())
}
/// Integer `dot` of a 2×3 and a 3×2 matrix.
#[test]
fn matmul_2() -> Result<(), ZyxError> {
    let lhs = Tensor::from([[2, 4, 3], [1, 5, 1]]);
    let rhs = Tensor::from([[2, 4], [3, 1], [5, 1]]);
    let product = lhs.dot(rhs)?;
    assert_eq!(product, [[31, 15], [22, 10]]);
    Ok(())
}
/// Checks 2-D integer `dot` against a naive triple-loop reference across a sweep of
/// shapes generated by the stepped ranges below.
///
/// On mismatch the panic message names the offending `(m, k, n)` so the failing shape
/// can be reproduced directly, matching `batched_matmul`'s diagnostics.
#[test]
fn matmul_1() -> Result<(), ZyxError> {
    for m in (56..576).step_by(259) {
        for k in (12..890).step_by(231) {
            for n in (5..97).step_by(71) {
                let x_data: Vec<Vec<i32>> =
                    (0..m).map(|i| (0..k).map(|j| i as i32 + j as i32).collect()).collect();
                let y_data: Vec<Vec<i32>> =
                    (0..k).map(|i| (0..n).map(|j| i as i32 - j as i32).collect()).collect();
                let x = Tensor::from(x_data.clone());
                let y = Tensor::from(y_data.clone());
                let z = x.dot(y)?;
                // Reference result: expected[i][j] = Σ_kk x[i][kk] * y[kk][j].
                let mut expected = vec![vec![0i32; n]; m];
                for i in 0..m {
                    for kk in 0..k {
                        for j in 0..n {
                            expected[i][j] += x_data[i][kk] * y_data[kk][j];
                        }
                    }
                }
                if z != expected {
                    panic!("Matmul mismatch for m={}, k={}, n={}", m, k, n);
                }
            }
        }
    }
    Ok(())
}
/// Batched 3-D integer `dot` versus a naive per-batch reference over a sweep of
/// batch/row/inner/column sizes. Also checks the output shape and dtype.
#[test]
fn batched_matmul() -> Result<(), ZyxError> {
    for b in (19..25).step_by(4) {
        for m in (32..67).step_by(31) {
            for k in (16..256).step_by(173) {
                for n in (8..128).step_by(59) {
                    let lhs: Vec<Vec<Vec<i32>>> = (0..b)
                        .map(|batch| {
                            (0..m)
                                .map(|row| {
                                    (0..k)
                                        .map(|col| batch as i32 + row as i32 + col as i32)
                                        .collect()
                                })
                                .collect()
                        })
                        .collect();
                    let rhs: Vec<Vec<Vec<i32>>> = (0..b)
                        .map(|batch| {
                            (0..k)
                                .map(|row| {
                                    (0..n)
                                        .map(|col| batch as i32 + row as i32 - col as i32)
                                        .collect()
                                })
                                .collect()
                        })
                        .collect();
                    // Reference: expected[bb][i][j] = Σ_kk lhs[bb][i][kk] * rhs[bb][kk][j],
                    // accumulated in increasing kk order.
                    let expected: Vec<Vec<Vec<i32>>> = lhs
                        .iter()
                        .zip(&rhs)
                        .map(|(lhs_mat, rhs_mat)| {
                            lhs_mat
                                .iter()
                                .map(|lhs_row| {
                                    (0..n)
                                        .map(|j| {
                                            lhs_row
                                                .iter()
                                                .zip(rhs_mat)
                                                .map(|(&a, rhs_row)| a * rhs_row[j])
                                                .sum()
                                        })
                                        .collect()
                                })
                                .collect()
                        })
                        .collect();
                    let product = Tensor::from(lhs).dot(Tensor::from(rhs))?;
                    let expected_shape = vec![b as u64, m as u64, n as u64];
                    assert_eq!(
                        product.shape(),
                        expected_shape,
                        "Shape mismatch: expected {:?}, got {:?}",
                        expected_shape,
                        product.shape()
                    );
                    assert_eq!(
                        product.dtype(),
                        DType::I32,
                        "Dtype mismatch: expected I32, got {:?}",
                        product.dtype()
                    );
                    if product != expected {
                        panic!("Batched matmul mismatch for b={}, m={}, k={}, n={}", b, m, k, n);
                    }
                }
            }
        }
    }
    Ok(())
}
/// A boolean tensor round-trips through storage unchanged.
#[test]
fn boolean_buffer() -> Result<(), ZyxError> {
    let flags = [true, true, false, true];
    let stored = Tensor::from(flags);
    assert_eq!(stored, flags);
    Ok(())
}
/// Row sums (a reduce) followed by an expand to 2×2.
#[test]
fn mix_expand_reduce() -> Result<(), ZyxError> {
    let tiled = Tensor::from([[2i32, 4, 3], [1, 5, 1]])
        .sum([1])?
        .expand([2, 2])?;
    assert_eq!(tiled, [[9i32, 7], [9, 7]]);
    Ok(())
}
/// Row sums followed by one trailing zero of padding.
#[test]
fn mix_pad_reduce() -> Result<(), ZyxError> {
    let padded = Tensor::from([[2i32, 4, 3], [1, 5, 1]])
        .sum([1])?
        .rpad_zeros([(0, 1)])?;
    assert_eq!(padded, [9i32, 7, 0]);
    Ok(())
}
/// Pads one leading zero onto the last axis, then transposes.
#[test]
fn mix_permute_pad() -> Result<(), ZyxError> {
    let transposed = Tensor::from([[2i32, 4, 3], [1, 5, 1]])
        .rpad_zeros([(1, 0)])?
        .t();
    assert_eq!(transposed, [[0i32, 0], [2, 1], [4, 5], [3, 1]]);
    Ok(())
}
/// From one row-sum vector, builds a row-broadcast (plain expand) and a column-broadcast
/// (reshape to 2×1 then expand). Both are realized together.
#[test]
fn mix_expand_reshape_reduce() -> Result<(), ZyxError> {
    let (by_row, by_col) = {
        let sums = Tensor::from([[2i32, 4, 3], [1, 5, 1]]).sum([1])?;
        (sums.expand([2, 2])?, sums.reshape([2, 1])?.expand([2, 2])?)
    };
    Tensor::realize([&by_col, &by_row])?;
    assert_eq!(by_row, [[9i32, 7], [9, 7]]);
    assert_eq!(by_col, [[9i32, 9], [7, 7]]);
    Ok(())
}
/// Pads a 2×5 tensor to 5×6, reshapes the 30 elements to [2, 1, 3, 5], then expands the
/// singleton axis to 2.
#[test]
fn mix_pad_reshape_expand() -> Result<(), ZyxError> {
    let result = Tensor::from([[2, 4, 3, 3, 4], [1, 2, 1, 5, 1]])
        .rpad_zeros([(1, 0), (2, 1)])?
        .reshape([2, 1, 3, 5])?
        .expand([2, 2, 3, 5])?;
    let first = [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 2, 4]];
    let second = [[3, 3, 4, 0, 1], [2, 1, 5, 1, 0], [0, 0, 0, 0, 0]];
    assert_eq!(result, [[first, first], [second, second]]);
    Ok(())
}
/// Permutes a [1, 3, 2, 1] tensor to [1, 2, 3, 1], then reshapes it to [1, 2, 1, 3, 1],
/// checking the shape at each step and the final values.
#[test]
fn mix_reshape1() -> Result<(), ZyxError> {
    let mut x = Tensor::from([[[[2i32], [4]], [[3], [1]], [[5], [1]]]]);
    x = x.permute([0, 2, 1, 3])?;
    assert_eq!(x.shape(), [1, 2, 3, 1]);
    // Propagate a reshape failure as the test's error instead of panicking,
    // consistent with the other `?`-based calls in this test.
    x = x.reshape([1, 2, 1, 3, 1])?;
    Tensor::realize([&x])?;
    assert_eq!(x.shape(), [1, 2, 1, 3, 1]);
    assert_eq!(x, [[[[[2i32], [3], [5]]], [[[4], [1], [1]]]]]);
    Ok(())
}
/// `pool([2, 2], 1, 1)` on a 3×3 ramp yields a 2×2 grid of 2×2 windows.
#[test]
fn pool() -> Result<(), ZyxError> {
    let ramp: Vec<i32> = (0..9).collect();
    let windows = Tensor::from(ramp).reshape((3, 3))?.pool([2, 2], 1, 1)?;
    let top = [[[0, 1], [3, 4]], [[1, 2], [4, 5]]];
    let bottom = [[[3, 4], [6, 7]], [[4, 5], [7, 8]]];
    assert_eq!(windows, [top, bottom]);
    Ok(())
}
/// Cumulative sum along axis 1 of a 3×3 ramp.
#[test]
fn cumsum() -> Result<(), ZyxError> {
    let ramp: Vec<i32> = (0..9).collect();
    let running = Tensor::from(ramp).reshape((3, 3))?.cumsum(1)?;
    assert_eq!(running, [[0, 1, 3], [3, 7, 12], [6, 13, 21]]);
    Ok(())
}
/// `arange(0, 10, 2)` produces the even numbers below 10.
#[test]
fn arange_3() -> Result<(), ZyxError> {
    let evens = Tensor::arange(0, 10, 2)?;
    assert_eq!(evens, [0, 2, 4, 6, 8]);
    Ok(())
}
/// Adding a scalar on the left (`f32 + Tensor`) shifts every element.
#[test]
fn const_() -> Result<(), ZyxError> {
    let base = Tensor::from([[3f32, 4., 2.], [4., 3., 2.]]);
    let shifted = 1f32 + base;
    assert_eq!(shifted, [[4f32, 5., 3.], [5., 4., 3.]]);
    Ok(())
}
/// A scalar tensor expands to a 1×1 matrix holding the same value.
#[test]
fn graph_shapes() -> Result<(), ZyxError> {
    let scalar: Tensor = 2.into();
    let matrix = scalar.expand([1, 1])?;
    assert_eq!(matrix, [[2]]);
    Ok(())
}
/// Concatenates two 2×2 tensors along axis 0 (stacking rows) and along axis 1
/// (side by side).
#[test]
fn cat() -> Result<(), ZyxError> {
    let left = Tensor::from([[1, 2], [3, 4]]);
    let right = Tensor::from([[5, 6], [7, 8]]);
    let vertical = Tensor::cat([&left, &right], 0)?;
    assert_eq!(vertical, [[1, 2], [3, 4], [5, 6], [7, 8]]);
    let horizontal = Tensor::cat([&left, &right], 1)?;
    assert_eq!(horizontal, [[1, 2, 5, 6], [3, 4, 7, 8]]);
    Ok(())
}
/// `rpad_zeros` with pairs listed from the last axis: (4, 3) on columns, (1, 2) on rows.
#[test]
fn pad_zeros() -> Result<(), ZyxError> {
    let padded = Tensor::from([[2, 3], [4, 5]]).rpad_zeros([(4, 3), (1, 2)])?;
    let blank = [0, 0, 0, 0, 0, 0, 0, 0, 0];
    assert_eq!(
        padded,
        [
            blank,
            [0, 0, 0, 0, 2, 3, 0, 0, 0],
            [0, 0, 0, 0, 4, 5, 0, 0, 0],
            blank,
            blank
        ]
    );
    Ok(())
}
/// One-hot encodes [2, 3, 4] with 4 classes; the out-of-range index 4 gives an all-zero row.
/// Runs only when the backend supports I64.
#[test]
fn one_hot() -> Result<(), ZyxError> {
    if Tensor::supports(DType::I64) {
        let labels = Tensor::from([2, 3, 4]);
        let encoded = labels.one_hot(4);
        assert_eq!(encoded, [[0, 0, 1, 0], [0, 0, 0, 1], [0, 0, 0, 0]]);
    }
    Ok(())
}
/// `Tensor::ones` fills a 2×3 I32 tensor with 1.
#[test]
fn ones() {
    let filled = Tensor::ones([2, 3], DType::I32);
    let row = [1i32; 3];
    assert_eq!(filled, [row, row]);
}
/// Builds, checks and drops the same addition twice in a row, so the second round
/// runs after the first round's tensors have been released.
#[test]
fn graph_node_reuse() {
    for _ in 0..2 {
        let lhs = Tensor::from([4, 2, 3]);
        let rhs = Tensor::from([4, 2, 3]);
        let total = lhs + rhs;
        assert_eq!(total, [8, 4, 6]);
    }
}
/// Slicing column 2 out of a 2×3 tensor.
#[test]
fn slice1() {
    let matrix = Tensor::from([[2, 3, 1], [2, 1, 4]]);
    let last_column = matrix.slice((.., 2..3)).unwrap();
    assert_eq!(last_column, [[1], [4]]);
}
/// Exercises `slice` on rank-2, rank-3 and rank-4 tensors. Ranges are listed from the
/// first axis; trailing axes that are not mentioned are kept whole.
#[test]
fn slicing_comprehensive() {
    let x = Tensor::from([[2, 3, 1], [2, 1, 4]]);
    assert_eq!(x.slice((.., 2..3)).unwrap(), [[1], [4]]);
    assert_eq!(x.slice((.., ..)).unwrap(), [[2, 3, 1], [2, 1, 4]]);
    assert_eq!(x.slice((0..1, ..)).unwrap(), [[2, 3, 1]]);
    assert_eq!(x.slice((1..2, 0..1)).unwrap(), [[2]]);
    assert_eq!(x.slice(1..2).unwrap(), [[2, 1, 4]]);
    let y = Tensor::from([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]);
    assert_eq!(y.slice((.., .., 1..3)).unwrap(), [[[2, 3], [5, 6]], [[8, 9], [11, 12]]]);
    assert_eq!(y.slice((.., 1..2, ..)).unwrap(), [[[4, 5, 6]], [[10, 11, 12]]]);
    assert_eq!(y.slice((0..1, ..)).unwrap(), [[[1, 2, 3], [4, 5, 6]]]);
    assert_eq!(y.slice(1..2).unwrap(), [[[7, 8, 9], [10, 11, 12]]]);
    assert_eq!(y.slice((.., 0..1)).unwrap(), [[[1, 2, 3]], [[7, 8, 9]]]);
    let z = Tensor::from([
        [[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
        [[[9, 10], [11, 12]], [[13, 14], [15, 16]]],
    ]);
    assert_eq!(
        z.slice((.., .., .., 1..2)).unwrap(),
        [[[[2], [4]], [[6], [8]]], [[[10], [12]], [[14], [16]]]]
    );
    assert_eq!(z.slice((0..1, .., 0..1, ..)).unwrap(), [[[[1, 2]], [[5, 6]]]]);
    assert_eq!(z.slice(1..2).unwrap(), [[[[9, 10], [11, 12]], [[13, 14], [15, 16]]]]);
    // Five ranges on a rank-4 tensor: the result may be Ok or Err. The only requirement
    // checked is that the call returns instead of panicking. The previous
    // `is_err() || is_ok()` assertion was always true and hid that intent.
    let _ = z.slice((.., .., .., .., ..));
}
/// Exercises `rslice`, whose ranges (per these expectations) are listed from the LAST axis
/// backwards: a single range restricts the last axis, and the second tuple element
/// restricts the axis before it.
#[test]
fn rslicing_comprehensive() {
    let x = Tensor::from([[4, 6, 8], [5, 7, 9]]);
    assert_eq!(x.rslice((.., 0..1)).unwrap(), [[4, 6, 8]]);
    assert_eq!(x.rslice((.., 1..2)).unwrap(), [[5, 7, 9]]);
    assert_eq!(x.rslice((.., ..)).unwrap(), [[4, 6, 8], [5, 7, 9]]);
    assert_eq!(x.rslice(0..1).unwrap(), [[4], [5]]);
    assert_eq!(x.rslice((0..1, 0..1)).unwrap(), [[4]]);
    assert_eq!(x.rslice(1..2).unwrap(), [[6], [7]]);
    let y = Tensor::from([[[2, 4, 6], [8, 10, 12]], [[14, 16, 18], [20, 22, 24]]]);
    assert_eq!(y.rslice((.., .., 0..1)).unwrap(), [[[2, 4, 6], [8, 10, 12]]]);
    assert_eq!(y.rslice((.., .., 1..2)).unwrap(), [[[14, 16, 18], [20, 22, 24]]]);
    assert_eq!(y.rslice((.., 0..1, ..)).unwrap(), [[[2, 4, 6]], [[14, 16, 18]]]);
    assert_eq!(y.rslice((1..2, .., ..)).unwrap(), [[[4], [10]], [[16], [22]]]);
    assert_eq!(y.rslice(1..2).unwrap(), [[[4], [10]], [[16], [22]]]);
    assert_eq!(y.rslice((.., 1..2)).unwrap(), [[[8, 10, 12]], [[20, 22, 24]]]);
    let z = Tensor::from([
        [[[3, 5], [7, 9]], [[11, 13], [15, 17]]],
        [[[19, 21], [23, 25]], [[27, 29], [31, 33]]],
    ]);
    assert_eq!(
        z.rslice((.., .., .., 0..1)).unwrap(),
        [[[[3, 5], [7, 9]], [[11, 13], [15, 17]]]]
    );
    assert_eq!(
        z.rslice((.., .., .., 1..2)).unwrap(),
        [[[[19, 21], [23, 25]], [[27, 29], [31, 33]]]]
    );
    assert_eq!(z.rslice((0..1, .., 0..1, ..)).unwrap(), [[[[3], [7]]], [[[19], [23]]]]);
    assert_eq!(
        z.rslice(1..2).unwrap(),
        [[[[5], [9]], [[13], [17]]], [[[21], [25]], [[29], [33]]]]
    );
    // Five ranges on a rank-4 tensor: the result may be Ok or Err. The only requirement
    // checked is that the call returns instead of panicking. The previous
    // `is_err() || is_ok()` assertion was always true and hid that intent.
    let _ = z.rslice((.., .., .., .., ..));
}
/// Splits the 3 columns of a 2×3 tensor into widths 2 and 1 along axis 1.
#[test]
fn split1() {
    let matrix = Tensor::from([[2, 3, 1], [2, 1, 4]]);
    let pieces = matrix.split([2, 1], 1).unwrap();
    assert_eq!(pieces[0], [[2, 3], [2, 1]]);
    assert_eq!(pieces[1], [[1], [4]]);
}
/// Positive and negative (cropping) padding with `pad_zeros` and `rpad_zeros`.
///
/// `pad_zeros` pairs apply from the first axis and `rpad_zeros` pairs from the last axis.
/// An impossible crop must fail with `ZyxError::ShapeError`.
#[test]
fn more_padding() -> Result<(), ZyxError> {
    let vector = Tensor::from([1, 2, 3, 4, 5]);
    let grown = [0, 0, 1, 2, 3, 4, 5, 0, 0, 0];
    assert_eq!(vector.pad_zeros([(2, 3)])?, grown);
    assert_eq!(vector.rpad_zeros([(2, 3)])?, grown);
    assert_eq!(vector.pad_zeros([(-1, -2)])?, [2, 3]);
    assert_eq!(vector.rpad_zeros([(-1, -2)])?, [2, 3]);

    let matrix = Tensor::from([[1, 2, 3], [4, 5, 6]]);
    let framed = [[0, 0, 0, 0, 0, 0], [0, 0, 1, 2, 3, 0], [0, 0, 4, 5, 6, 0], [0, 0, 0, 0, 0, 0]];
    assert_eq!(matrix.pad_zeros([(1, 1), (2, 1)])?, framed);
    assert_eq!(matrix.rpad_zeros([(2, 1), (1, 1)])?, framed);

    let failure = vector.rpad_zeros([(-10, 0)]).unwrap_err();
    assert!(matches!(failure, ZyxError::ShapeError(_)));
    Ok(())
}
/// Padding only some axes of a rank-4 tensor. Per the expected shapes, `pad_zeros` pairs
/// start at the leading axis and `rpad_zeros` pairs start at the trailing axis.
#[test]
fn partial_padding() -> Result<(), ZyxError> {
    let block = Tensor::from([
        [[[1i32, 2], [3, 4]], [[5, 6], [7, 8]]],
        [[[9, 10], [11, 12]], [[13, 14], [15, 16]]],
    ]);
    let zero = [[0, 0], [0, 0]];

    let leading = block.pad_zeros([(1, 1), (1, 0)])?;
    assert_eq!(leading.shape(), vec![4, 3, 2, 2]);
    assert_eq!(
        leading,
        [
            [zero, zero, zero],
            [zero, [[1, 2], [3, 4]], [[5, 6], [7, 8]]],
            [zero, [[9, 10], [11, 12]], [[13, 14], [15, 16]]],
            [zero, zero, zero]
        ]
    );

    let trailing = block.rpad_zeros([(1, 0), (1, 1)])?;
    assert_eq!(trailing.shape(), vec![2, 2, 4, 3]);
    let edge = [0, 0, 0];
    assert_eq!(
        trailing,
        [
            [
                [edge, [0, 1, 2], [0, 3, 4], edge],
                [edge, [0, 5, 6], [0, 7, 8], edge]
            ],
            [
                [edge, [0, 9, 10], [0, 11, 12], edge],
                [edge, [0, 13, 14], [0, 15, 16], edge]
            ]
        ]
    );
    Ok(())
}
/// Splits the 5 rows of a 5×2 arange along axis 0 into two different partitions.
#[test]
fn split2() -> Result<(), ZyxError> {
    let grid = Tensor::arange(0, 10, 1)?.reshape([5, 2])?;

    let three_way = grid.split([2, 2, 1], 0)?;
    assert_eq!(three_way[0], [[0, 1], [2, 3]]);
    assert_eq!(three_way[1], [[4, 5], [6, 7]]);
    assert_eq!(three_way[2], [[8, 9]]);

    let two_way = grid.split([1, 4], 0)?;
    assert_eq!(two_way[0], [[0, 1]]);
    assert_eq!(two_way[1], [[2, 3], [4, 5], [6, 7], [8, 9]]);
    Ok(())
}
/// Loads `x`, `y` and the stored product `z` from `./tests/xyz2.safetensors`, recomputes
/// `x.matmul(y)` and compares the result with `z` element by element.
/// Skipped when the backend lacks I64 support.
#[test]
fn matmul_disk() -> Result<(), ZyxError> {
    if !Tensor::supports(DType::I64) {
        return Ok(());
    }
    let mut xyz: std::collections::HashMap<String, Tensor> =
        Tensor::load("./tests/xyz2.safetensors")?;
    let z = xyz.remove("z").expect("xyz2.safetensors has no tensor named `z`");
    let y = xyz.remove("y").expect("xyz2.safetensors has no tensor named `y`");
    let x = xyz.remove("x").expect("xyz2.safetensors has no tensor named `x`");
    let dataz: Vec<i64> = z.try_into()?;
    let zz = x.matmul(y)?;
    let datazz: Vec<i64> = zz.try_into()?;
    // `zip` stops at the shorter input. Without this check a truncated or oversized
    // result would pass unnoticed.
    assert_eq!(
        dataz.len(),
        datazz.len(),
        "matmul result has {} elements, stored product has {}",
        datazz.len(),
        dataz.len()
    );
    for (i, (x, y)) in dataz.iter().zip(datazz).enumerate() {
        assert!(x.is_equal(y), "{x} != {y} at index {i}");
    }
    Ok(())
}
/// Softmax over all axes (`[]`) of a 3-element f32 vector.
#[test]
fn softmax_1() -> Result<(), ZyxError> {
    let logits = Tensor::from([2f32, 4., 3.]);
    let probs = logits.softmax([])?;
    let expected = [0.09003056585788726807f32, 0.66524088382720947266, 0.24472846090793609619];
    assert_eq!(probs, expected);
    Ok(())
}
/// A 2×3 · 3×2 integer product, then 2 leading and 1 trailing zero columns of padding.
#[test]
fn dot_pad() -> Result<(), ZyxError> {
    let rhs = Tensor::from([[2, 3], [1, 2], [4, 1]]);
    let padded = Tensor::from([[2, 3, 1], [2, 4, 1]])
        .dot(rhs)?
        .rpad_zeros([(2, 1)])?;
    assert_eq!(padded, [[0, 0, 11, 13, 0], [0, 0, 12, 15, 0]]);
    Ok(())
}
/// Expanding a 1024×1024 random tensor to six axes of 1024 and realizing it is expected
/// to panic. Every step is `unwrap`ped, so an error anywhere in the chain surfaces as the
/// panic.
#[test]
#[should_panic]
fn t3() {
    let base = Tensor::randn([1024, 1024], DType::F32).unwrap();
    let huge = base
        .expand([1024, 1024, 1024, 1024, 1024, 1024])
        .unwrap();
    Tensor::realize([&huge]).unwrap();
}
/// Hand-rolled layer normalisation with a weight and no bias.
///
/// Subtracts the mean and divides by `sqrt(var + eps)` over the trailing axes selected by
/// a negative range sized by the weight's rank.
#[test]
fn layer_norm() -> Result<(), ZyxError> {
    let weight = Some(Tensor::from([4f32, 5., 1., 2.]));
    let normalized_rank = weight.as_ref().unwrap().rank();
    let bias: Option<Tensor> = None;
    let epsilon = 0.00001f32;
    let input = Tensor::from([[3, 5, 2, 1], [6, 1, 4, 2]]).cast(DType::F32);
    // Negative axis range -rank..=-1, presumably the last `normalized_rank` axes.
    let axes = -(normalized_rank as i32)..=-1;
    let epsilon = Tensor::from(epsilon).cast(input.dtype());
    let centered = &input - input.mean_keepdim(axes.clone())?;
    let std_dev = (input.var_keepdim(axes)? + epsilon).sqrt();
    let mut normalized = centered / std_dev;
    if let Some(scale) = &weight {
        normalized = normalized * scale;
    }
    if let Some(shift) = &bias {
        normalized = normalized + shift;
    }
    assert_eq!(
        normalized,
        [
            [0.585539f32, 6.587314, -0.439154, -2.049387],
            [4.960858, -5.073606, 0.338240, -1.127468]
        ]
    );
    Ok(())
}
/// Realizes both `ln(x)` and `tanh(ln(x))` together, then checks the latter.
#[test]
fn multiple_stores() -> Result<(), ZyxError> {
    let values = Tensor::from([[3f32, 4., 2.], [5., 4., 1.]]);
    let logs = values.ln();
    let squashed = logs.tanh();
    Tensor::realize([&logs, &squashed])?;
    let expected = [
        [0.8000000119f32, 0.8823529482, 0.6000000238],
        [0.9230769277, 0.8823529482, 0.0000000000],
    ];
    assert_eq!(squashed, expected);
    Ok(())
}
/// `repeat([2, 3, 1])` on a 2×3 tensor yields shape [2, 6, 3]: the rows are tiled three
/// times and that block is stacked twice.
#[test]
fn repeat1() -> Result<(), ZyxError> {
    let tiled = Tensor::from([[2, 3, 1], [2, 4, 1]]).repeat([2, 3, 1])?;
    let (a, b) = ([2, 3, 1], [2, 4, 1]);
    let block = [a, b, a, b, a, b];
    assert_eq!(tiled, [block, block]);
    Ok(())
}
/// Realizes a transpose and an `exp` cast to I32 from the same source together.
/// The cast truncates the fractional part.
#[test]
fn mix_2() {
    let source = Tensor::from([[2f32, 3.], [4., 5.]]);
    let transposed = source.t();
    let truncated = source.exp().cast(DType::I32);
    Tensor::realize([&transposed, &truncated]).unwrap();
    assert_eq!(truncated, [[7i32, 20], [54, 148]]);
}
/// Seeded U8 random tensor; slicing columns `8..=-2` must reproduce known values.
/// Disabled for the wgpu backend.
#[cfg(not(feature = "wgpu"))]
#[test]
fn rand_get() -> Result<(), ZyxError> {
    Tensor::manual_seed(69420);
    let sample = Tensor::rand([3, 12], DType::U8)?.slice((.., 8..=-2))?;
    let expected = [[41u8, 171, 236], [212, 222, 77], [16, 125, 60]];
    assert_eq!(sample, expected);
    Ok(())
}
/// `gather` along axis 1: each output row picks its source row's columns by index.
/// Disabled for the wgpu backend.
#[cfg(not(feature = "wgpu"))]
#[test]
fn gather_test() -> Result<(), ZyxError> {
    let source = Tensor::from([
        [10u16, 20, 30, 40, 50],
        [11, 21, 31, 41, 51],
        [12, 22, 32, 42, 52],
    ]);
    let indices = Tensor::from([[0u16, 2, 4], [1, 3, 0], [4, 1, 2]]);
    let picked = source.gather(1, &indices)?;
    assert_eq!(picked, [[10u16, 30, 50], [21, 41, 11], [52, 22, 32]]);
    Ok(())
}
/// `Tensor::eye(8, I32)` is the 8×8 identity matrix.
#[test]
fn eye1() {
    let identity = Tensor::eye(8, DType::I32);
    // Build the reference identity instead of spelling out 64 literals.
    let expected: [[i32; 8]; 8] =
        std::array::from_fn(|row| std::array::from_fn(|col| i32::from(row == col)));
    assert_eq!(identity, expected);
}
/// Multiplies two random 1024×1024 f32 matrices with zyx and compares every element,
/// via `is_equal`, with the `matmul` reference helper (matrixmultiply).
#[allow(unused)]
#[test]
fn bench_mm1() -> Result<(), ZyxError> {
    const N: usize = 1024;
    let dtype = zyx::DType::F32;
    let x = Tensor::rand([N as u64, N as u64], dtype)?;
    let y = Tensor::rand([N as u64, N as u64], dtype)?;
    let x_data: Vec<f32> = x.clone().try_into()?;
    let y_data: Vec<f32> = y.clone().try_into()?;
    let z = x.matmul(&y)?;
    let z_data: Vec<f32> = z.try_into()?;
    let expected = matmul(&x_data, &y_data, N, N, N);
    // `zip` stops at the shorter side, so a wrongly-sized result would otherwise pass.
    assert_eq!(
        z_data.len(),
        expected.len(),
        "Wrong matmul: result has {} elements, expected {}",
        z_data.len(),
        expected.len()
    );
    for (i, (got, want)) in z_data.into_iter().zip(expected).enumerate() {
        assert!(got.is_equal(want), "Wrong matmul: {got} != {want} at index {i}");
    }
    Ok(())
}
/// A tensor built from a `Vec<Vec<i32>>` takes its shape from the nesting.
#[test]
fn double_vec() -> Result<(), ZyxError> {
    let rows = vec![vec![4, 1, 2], vec![4, 6, 2]];
    let matrix = Tensor::from(rows);
    assert_eq!(matrix.shape(), [2, 3]);
    Ok(())
}
/// Passes a tensor through eight `exp2`/`log2` round trips, then sums it with its own
/// transpose, flattened to length 6. Checks the result of `2x + xᵀ` after one more
/// round trip.
#[test]
fn binary_y_depends_on_x() -> Result<(), ZyxError> {
    let z = {
        let cast = Tensor::from([[2, 4, 1], [3, 2, 4]]).cast(DType::F32);
        let mut x = cast.exp2().log2();
        for _ in 1..8 {
            x = x.exp2().log2();
        }
        let y = x.permute([1, 0]).unwrap();
        let flat_x = x.reshape(6).unwrap();
        let flat_y = y.reshape(6).unwrap();
        let total = flat_x + flat_y + x.reshape(6).unwrap();
        total.exp2().log2()
    };
    assert_eq!(z, [6f32, 11., 6., 8., 5., 12.]);
    Ok(())
}
/// Integer `dot` of a 2×3 and a 3×2 matrix.
#[test]
fn dot5() {
    let lhs = Tensor::from([[2, 3, 1], [3, 4, 1]]);
    let rhs = Tensor::from([[2, 3], [2, 1], [4, 1]]);
    let product = lhs.dot(rhs).unwrap();
    assert_eq!(product, [[14, 10], [18, 14]]);
}
/// Convolves a 1×1×3×3 ramp with an all-ones 2×2 kernel and no bias. Each output is the
/// sum of one 2×2 window.
#[test]
fn conv1() -> Result<(), ZyxError> {
    let image = Tensor::arange(0f32, 9., 1.)?.reshape([1, 1, 3, 3])?;
    let kernel = Tensor::ones([1, 1, 2, 2], DType::F32);
    let response = image.conv(&kernel, None, 1, 1, 1, 0)?;
    assert_eq!(response, [[[[8f32, 12.], [20., 24.]]]]);
    Ok(())
}
/// Builds two dependent expression graphs, the second consuming the first, without
/// realizing either. Passes as long as graph construction and teardown succeed.
#[test]
fn graph_tensor_ordering() -> Result<(), ZyxError> {
    let first = {
        let x = Tensor::from([3f32, 4., 2.]);
        let mixed = x.exp2() + x.log2();
        mixed.exp2()
    };
    let _second = first.exp2() * first;
    Ok(())
}
/// Hand-written rotary embedding on a 1×1×2×6 input.
///
/// Splits the last axis into halves `a` and `-b`, combines them with the sin/cos tables,
/// and checks both the rotated half `ro` and the concatenation `[co, ro]`.
#[test]
fn rope_3() -> Result<(), ZyxError> {
    let z = {
        let xs = Tensor::from([[1f32, 4., 2., 4., 4., 3.], [4., 2., 4., 4., 3., 4.]])
            .reshape([1, 1, 2, 6])?;
        let sin = Tensor::from([1f32, 4., 2., 4., 4., 3.]).reshape([2, 3])?;
        let cos = Tensor::from([1f32, 4., 2., 4., 4., 3.]).reshape([2, 3])?;
        // Presumably the size of the last axis (6 here).
        let [d] = xs.rdims()?;
        let sin_freqs = sin.squeeze([0, 1]);
        let cos_freqs = cos.squeeze([0, 1]);
        // Use `?` so slicing/concatenation errors surface as the test's error value
        // instead of a panic.
        let a = xs.slice((.., .., .., ..d / 2))?;
        let b = -xs.slice((.., .., .., d / 2..))?;
        let ro = a.clone() * cos_freqs.clone() - b.clone() * sin_freqs.clone();
        assert_eq!(ro, [[[[5f32, 32., 10.], [32., 20., 24.]]]]);
        let co = a * sin_freqs + b * cos_freqs;
        Tensor::cat([&co, &ro], -1)?
    };
    assert_eq!(z.cast(DType::I32), [[[[-3i32, 0, -2, 5, 32, 10], [0, -4, 0, 32, 20, 24]]]]);
    Ok(())
}
/// The built-in `rope` must match the hand-written `rope_3` result on the same data.
#[test]
fn rope_4() -> Result<(), ZyxError> {
    let table = [1f32, 4., 2., 4., 4., 3.];
    let xs = Tensor::from([1f32, 4., 2., 4., 4., 3., 4., 2., 4., 4., 3., 4.])
        .reshape([1, 1, 2, 6])?;
    let sin = Tensor::from(table).reshape([2, 3])?;
    let cos = Tensor::from(table).reshape([2, 3])?;
    let rotated = xs.rope(&cos, &sin)?;
    let rounded = rotated.cast(DType::I32);
    assert_eq!(rounded, [[[[-3i32, 0, -2, 5, 32, 10], [0, -4, 0, 32, 20, 24]]]]);
    Ok(())
}
/// Interleaves expand/reshape/permute movements with `exp`/`ln` round trips and two sum
/// reductions.
#[test]
fn complex_movement_reduce() -> Result<(), ZyxError> {
    let lhs = Tensor::from([[[2f32, 3.]], [[4., 5.]]])
        .expand([2, 3, 2])?
        .exp()
        .ln()
        .reshape([2, 3, 2, 1])?;
    let rhs = Tensor::from([[2f32, 3., 1.], [4., 3., 2.]])
        .reshape([2, 3, 1, 1])?
        .expand([2, 3, 2, 1])?;
    let partial = (&lhs + &rhs).expand([2, 3, 2, 2])?.sum([3, 0])?;
    let totals = partial.exp().ln().permute([1, 0])?.sum([0])?;
    assert_eq!(totals, [52f32, 52., 40.]);
    Ok(())
}
/// Centres each row using a mean computed as `sum * 0.333…` (mixing i32 with an f32
/// scalar); the result is f32.
#[test]
fn mean1() -> Result<(), ZyxError> {
    let values = Tensor::from([[1i32, 2, 3], [4, 5, 6]]);
    let approx_mean = values.sum([1])? * 0.3333333333333f32;
    let column = approx_mean.reshape([2, 1])?;
    let centered = values - column;
    assert_eq!(centered, [[-1f32, 0., 1.], [-1., 0., 1.]]);
    Ok(())
}
/// Population variance along axis 0, computed by hand: mean, square, sum, divide by the
/// axis length.
#[test]
fn var1() -> Result<(), ZyxError> {
    let samples = Tensor::from([[1f32, 2., 3.], [4., 5., 6.]]);
    let [count] = samples.dims()?;
    let deviations = &samples - samples.mean_keepdim([0])?;
    let squared_sum = (&deviations * &deviations).sum([0])?;
    let variance = squared_sum / count as u32;
    assert_eq!(variance, [2.25f32, 2.25, 2.25]);
    Ok(())
}
/// Subtracting the keep-dim row mean of an i32 tensor centres each row.
#[test]
fn mean2() -> Result<(), ZyxError> {
    let values = Tensor::from([[1i32, 2, 3], [4, 5, 6]]);
    let row_mean = values.mean_keepdim([1])?;
    let centered = values - row_mean;
    assert_eq!(centered, [[-1i32, 0, 1], [-1, 0, 1]]);
    Ok(())
}
/// Population variance along axis 1, computed by hand.
#[test]
fn var2() -> Result<(), ZyxError> {
    let samples = Tensor::from([[1f32, 2., 3.], [4., 5., 6.]]);
    let [_, count] = samples.dims()?;
    let deviations = &samples - samples.mean_keepdim([1])?;
    let squared_sum = (&deviations * &deviations).sum([1])?;
    let variance = squared_sum / count as u32;
    assert_eq!(variance, [0.666666f32, 0.666666]);
    Ok(())
}
/// `var_correction` along axis 1 with correction 0 (population variance).
#[test]
fn var3() -> Result<(), ZyxError> {
    let samples = Tensor::from([[1f32, 2., 3.], [4., 5., 6.]]);
    let variance = samples.var_correction([1], 0)?;
    assert_eq!(variance, [0.666666f32, 0.666666]);
    Ok(())
}
/// `var_correction` with correction 0 along each axis of the same tensor.
#[test]
fn var4() -> Result<(), ZyxError> {
    let samples = Tensor::from([[1f32, 2., 3.], [4., 5., 6.]]);
    let per_column = samples.var_correction([0], 0)?;
    assert_eq!(per_column, [2.25f32, 2.25, 2.25]);
    let per_row = samples.var_correction([1], 0)?;
    assert_eq!(per_row, [0.666666f32, 0.666666]);
    Ok(())
}
/// Softmax of a 2×3 tensor over all axes (`[]`), over axis 0 and over axis 1.
#[test]
fn softmax_2() -> Result<(), ZyxError> {
    let logits = Tensor::from([[2f32, 4., 3.], [4., 2., 3.]]);

    let global = logits.softmax([])?;
    assert_eq!(
        global,
        [
            [0.0450152867f32, 0.3326204717, 0.1223642379],
            [0.3326204717, 0.0450152867, 0.1223642379]
        ]
    );

    let per_column = logits.softmax([0])?;
    assert_eq!(
        per_column,
        [[0.1192029193f32, 0.8807970285, 0.5], [0.8807970285, 0.1192029193, 0.5]]
    );

    let per_row = logits.softmax([1])?;
    assert_eq!(
        per_row,
        [
            [0.0900305659f32, 0.6652408838, 0.2447284609],
            [0.6652408838, 0.0900305659, 0.2447284609]
        ]
    );
    Ok(())
}
/// Attention forward pass on a 1×4×4 input with 4 heads (no mask is applied in this code).
///
/// Projects to q/k/v with one weight, splits heads with reshape+transpose, applies
/// softmax(q·kᵀ / sqrt(head_dim))·v, and merges the heads back.
#[test]
fn causal_self_attention() -> Result<(), ZyxError> {
    let out = {
        let dtype = DType::F32;
        let n_embd = 4;
        let n_head = 4;
        let qkv_weight = Tensor::from([
            [3, 1, 2, 3, 1, 2, 5, 4, 2, 3, 1, 3],
            [1, 1, 2, 3, 1, 2, 5, 4, 2, 3, 1, 3],
            [3, 1, 5, 3, 1, 2, 5, 4, 2, 3, 1, 3],
            [3, 1, 2, 3, 1, 2, 5, 8, 2, 3, 1, 3],
        ])
        .t()
        .cast(dtype);
        let x = Tensor::from([[[1, 0, 4, 2], [2, 5, 0, 1], [0, 8, 1, 0], [5, 1, 0, 0]]]).cast(dtype);
        let [batch, seq, channels] = x.shape()[..] else {
            return Err(ZyxError::ShapeError("x must have exactly 3 dims, b, t, c".into()));
        };
        let head_dim = channels / n_head;
        // [batch, seq, channels] -> [batch, n_head, seq, head_dim]
        let split_heads = |part: Tensor| -> Result<Tensor, ZyxError> {
            Ok(part.reshape([batch, seq, n_head, head_dim])?.transpose(1, 2)?)
        };
        let mut parts = x.dot(qkv_weight.t())?.split([n_embd, n_embd, n_embd], 2)?;
        let v = parts.pop().unwrap();
        let k = parts.pop().unwrap();
        let q = parts.pop().unwrap();
        let k = split_heads(k)?;
        let q = split_heads(q)?;
        let v = split_heads(v)?;
        let scores = q.dot(k.t())? * (1f32 / (*k.shape().last().unwrap() as f32).sqrt());
        let weights = scores.softmax([-1])?;
        let merged = weights.dot(v)?;
        merged.transpose(1, 2)?.reshape([batch, seq, channels])?
    };
    let row = [18f32, 27., 9., 24.];
    assert_eq!(out, [[row, row, row, row]]);
    Ok(())
}
/// Ten repeated vector–matrix products `x ← x · W` in i32.
#[test]
fn dot6() -> Result<(), ZyxError> {
    let weights = Tensor::from([[2i32, 3, 2], [2, 1, 1], [4, 1, 4]]);
    let start = Tensor::from([2i32, 3, 1]);
    let result = (0..10).try_fold(start, |acc, _| acc.matmul(&weights))?;
    assert_eq!(result, [492004322i32, 323660910, 445342573]);
    Ok(())
}
/// One affine step `x · W + b` on i32 vectors.
///
/// Previously this test only realized the result and asserted nothing. It now pins the
/// value, which is consistent with the row-vector semantics exercised by `dot4`/`dot6`.
#[test]
fn dot3() -> Result<(), ZyxError> {
    let x = Tensor::from([2i32, 3, 1]);
    let w = Tensor::from([[2i32, 3, 2], [2, 1, 1], [4, 1, 4]]);
    let b = Tensor::from([2i32, 3, 5]);
    let x = x.dot(&w)? + &b;
    Tensor::realize([&x])?;
    // x · W = [2·2 + 3·2 + 1·4, 2·3 + 3·1 + 1·1, 2·2 + 3·1 + 1·4] = [14, 10, 11]; then + b.
    assert_eq!(x, [16i32, 13, 16]);
    Ok(())
}
/// Ten repeated affine steps `x ← x · W + b` in i32.
#[test]
fn dot4() -> Result<(), ZyxError> {
    let weights = Tensor::from([[2i32, 3, 2], [2, 1, 1], [4, 1, 4]]);
    let bias = Tensor::from([2i32, 3, 5]);
    let start = Tensor::from([2i32, 3, 1]);
    let result = (0..10).try_fold(start, |acc, _| -> Result<Tensor, ZyxError> {
        Ok(acc.dot(&weights)? + &bias)
    })?;
    assert_eq!(result, [671627020i32, 441824135, 607929878]);
    Ok(())
}
/// Per-element cross-entropy terms: negative natural-log softmax (max-shifted for
/// stability) multiplied by a one-hot target.
#[test]
fn cross_entropy() -> Result<(), ZyxError> {
    let logits = Tensor::from([[2, 3, 4], [5, 6, 7]]).cast(DType::F32);
    let target = Tensor::from([[0, 1, 0], [0, 0, 1]]).cast(DType::F32);
    let shifted = &logits - logits.max_keepdim([1])?;
    let log_partition = shifted.exp().sum_keepdim([1])?.ln();
    let neg_log_softmax = log_partition - shifted;
    let loss_terms = neg_log_softmax * target;
    assert_eq!(
        loss_terms,
        [[0.000000f32, 1.407606, 0.000000], [0.000000, 0.000000, 0.407606]]
    );
    Ok(())
}
/// Pads [2, 3, 4] with one 0 on each side, adds 1, then checks the length and index 1.
#[test]
fn test_padding_on_elementwise_kernel() {
    let base = Tensor::from([2, 3, 4]);
    let shifted = base.pad([(1, 1)], 0).unwrap() + 1;
    assert_eq!(shifted.shape(), &[5]);
    let second = shifted.slice(1).unwrap();
    assert_eq!(second, 3);
}
/// Expands a 3-vector to 3×3, adds the f64 scalar 1.0, and checks the shape and one
/// element. Runs only when the backend supports both I64 and F64.
#[test]
fn test_expand_on_elementwise_kernel() {
    if Tensor::supports(DType::I64) && Tensor::supports(DType::F64) {
        let row = Tensor::from([2, 3, 4]);
        let shifted = row.expand([3, 3]).unwrap() + 1.0;
        assert_eq!(shifted.shape(), &[3, 3]);
        assert_eq!(shifted.slice((1, 1)).unwrap(), 4.0);
    }
}
/// Reshapes a 3-vector to 3×1, scales it by 2.0, and checks the shape and the last element.
/// Runs only when the backend supports I64.
#[test]
fn test_reshape_on_elementwise_kernel() {
    if Tensor::supports(DType::I64) {
        let column = Tensor::from([2, 3, 4]).reshape([3, 1]).unwrap();
        let scaled = column * 2.0;
        assert_eq!(scaled.shape(), &[3, 1]);
        assert_eq!(scaled.slice((2, 0)).unwrap(), 8.0);
    }
}
/// Permutes a 2×2×2 tensor with `[2, 0, 1]`, adds 1, and reads back element (1, 0, 1).
#[test]
fn test_permute_on_elementwise_kernel() {
    let cube = Tensor::from([[[1.0f32, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]);
    let shifted = cube.permute([2, 0, 1]).unwrap() + 1.0f32;
    assert_eq!(shifted.shape(), &[2, 2, 2]);
    let picked = shifted.slice((1, 0, 1)).unwrap();
    let value: f32 = picked.item();
    assert_eq!(value, 5.0f32);
}
/// Pads the columns of a 2×2 tensor with one zero on each side, then sums over rows.
#[test]
fn test_padding_on_reduce_kernel() {
    let square = Tensor::from([[1.0f32, 2.0], [3.0, 4.0]]);
    let column_sums = square
        .pad([(0, 0), (1, 1)], 0.0f32)
        .unwrap()
        .sum([0])
        .unwrap();
    assert_eq!(column_sums.shape(), &[4]);
    for (index, expected) in [0.0f32, 4.0, 6.0, 0.0].into_iter().enumerate() {
        assert_eq!(column_sums.slice(index).unwrap(), expected);
    }
}
/// Expands a 3×1 column to 3×2 and takes the row mean; row 1 averages two 2.0 values.
#[test]
fn test_expand_on_reduce_kernel() {
    let column = Tensor::from([[1.0f32], [2.0], [3.0]]);
    let row_means = column.expand([3, 2]).unwrap().mean([1]).unwrap();
    assert_eq!(row_means.shape(), &[3]);
    assert_eq!(row_means.slice(1).unwrap(), 2.0f32);
}
/// Flattens a 2×2 tensor and sums its only axis, giving a single-element tensor.
#[test]
fn test_reshape_on_reduce_kernel() {
    let square = Tensor::from([[1.0f32, 2.0], [3.0, 4.0]]);
    let total = square.reshape([4]).unwrap().sum([0]).unwrap();
    assert_eq!(total.shape(), &[1]);
    let scalar = total.item::<f32>();
    assert_eq!(scalar, 10.0f32);
}
/// Permutes a 2×2×2 tensor with `[1, 2, 0]` and sums the new last axis.
#[test]
fn test_permute_on_reduce_kernel() {
    let cube = Tensor::from([[[1.0f32, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]);
    let sums = cube.permute([1, 2, 0]).unwrap().sum([2]).unwrap();
    assert_eq!(sums.shape(), &[2, 2]);
    assert_eq!(sums.slice((0, 0)).unwrap(), 6.0f32);
}
/// `sin(exp2(x))` over a 5488-element f32 arange; only checks that realization succeeds.
#[test]
fn arange_1() -> Result<(), ZyxError> {
    let ramp = Tensor::arange(0, 784 * 7, 1)?.cast(DType::F32);
    let wave = ramp.exp2().sin();
    Tensor::realize([&wave]).unwrap();
    Ok(())
}
/// `sin(exp2(x))` over a 2-element arange, with no explicit cast; only checks that
/// realization succeeds.
#[test]
fn arange_2() {
    let ramp = Tensor::arange(0, 2, 1).unwrap();
    let wave = ramp.exp2().sin();
    Tensor::realize([&wave]).unwrap();
}