use zyx::{Tensor, ZyxError};
/// Multiplies a 2x3 integer tensor by a 3x2 integer tensor with `dot`
/// and checks the resulting 2x2 product.
#[test]
fn matmul() -> Result<(), ZyxError> {
    let lhs = Tensor::from([[2, 4, 3], [1, 5, 1]]);
    let rhs = Tensor::from([[2, 4], [3, 1], [5, 1]]);

    let product = lhs.dot(rhs)?;

    // Row 0: 2*2 + 4*3 + 3*5 = 31, 2*4 + 4*1 + 3*1 = 15
    // Row 1: 1*2 + 5*3 + 1*5 = 22, 1*4 + 5*1 + 1*1 = 10
    assert_eq!(product, [[31, 15], [22, 10]]);
    Ok(())
}
/// Reduces a 2x3 tensor with `sum(1)` and then zero-pads the result with
/// one trailing element.
#[test]
fn pad_reduce() -> Result<(), ZyxError> {
    // Row sums are 2+4+3 = 9 and 1+5+1 = 7; the pad appends a single zero.
    let x = Tensor::from([[2, 4, 3], [1, 5, 1]])
        .sum(1)?
        .pad_zeros([(0, 1)])?;

    assert_eq!(x, [9, 7, 0]);
    Ok(())
}
/// Zero-pads a 2x3 tensor with one leading element per row and then
/// transposes it with `t()`.
#[test]
fn permute_pad() -> Result<(), ZyxError> {
    let x = Tensor::from([[2, 4, 3], [1, 5, 1]])
        .pad_zeros([(1, 0)])?
        .t();

    // Padded rows are [0, 2, 4, 3] and [0, 1, 5, 1]; the transpose turns
    // them into columns.
    assert_eq!(x, [[0, 0], [2, 1], [4, 5], [3, 1]]);
    Ok(())
}
/// Expands one reduction result in two different ways and realizes both
/// expansions together in a single `Tensor::realize` call.
#[test]
fn expand_reduce() -> Result<(), ZyxError> {
    let (x, y) = {
        // Row sums: [9, 7]
        let row_sums = Tensor::from([[2, 4, 3], [1, 5, 1]]).sum(1)?;
        // Expanded directly to 2x2.
        let y = row_sums.expand([2, 2])?;
        // Reshaped to a 2x1 column first, then expanded to 2x2.
        let x = row_sums.reshape([2, 1])?.expand([2, 2])?;
        (x, y)
    };

    Tensor::realize([&x, &y])?;

    assert_eq!(y, [[9, 7], [9, 7]]);
    assert_eq!(x, [[9, 9], [7, 7]]);
    Ok(())
}
/// Chains a zero-pad, a reshape and an expand on a 2x5 tensor and checks
/// the final 2x2x3x5 contents.
#[test]
fn pad_reshape_expand() -> Result<(), ZyxError> {
    // The pad grows the 2x5 input to 5x6 (30 elements), which is then
    // viewed as 2x1x3x5 and expanded along the second dimension.
    let x = Tensor::from([[2, 4, 3, 3, 4], [1, 2, 1, 5, 1]])
        .pad_zeros([(1, 0), (2, 1)])?
        .reshape([2, 1, 3, 5])?
        .expand([2, 2, 3, 5])?;

    assert_eq!(
        x,
        [
            [
                [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 2, 4]],
                [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 2, 4]]
            ],
            [
                [[3, 3, 4, 0, 1], [2, 1, 5, 1, 0], [0, 0, 0, 0, 0]],
                [[3, 3, 4, 0, 1], [2, 1, 5, 1, 0], [0, 0, 0, 0, 0]]
            ]
        ]
    );
    Ok(())
}
/// Applies `pool([2, 2], 1, 1)` to a 3x3 tensor holding 0..9 and checks the
/// four 2x2 windows it produces.
#[test]
fn pool() -> Result<(), ZyxError> {
    let values: Vec<i32> = (0..9).collect();
    let x = Tensor::from(values)
        .reshape((3, 3))?
        .pool([2, 2], 1, 1)?;

    // Input grid:
    //   0 1 2
    //   3 4 5
    //   6 7 8
    assert_eq!(
        x,
        [
            [[[0, 1], [3, 4]], [[1, 2], [4, 5]]],
            [[[3, 4], [6, 7]], [[4, 5], [7, 8]]]
        ]
    );
    Ok(())
}
/// Applies `cumsum(1)` to a 3x3 tensor holding 0..9 and checks the running
/// totals within each row.
#[test]
fn cumsum() -> Result<(), ZyxError> {
    let values: Vec<i32> = (0..9).collect();
    let x = Tensor::from(values)
        .reshape((3, 3))?
        .cumsum(1)?;

    // Rows [0, 1, 2], [3, 4, 5], [6, 7, 8] accumulated left to right.
    assert_eq!(x, [[0, 1, 3], [3, 7, 12], [6, 13, 21]]);
    Ok(())
}
/// Builds a tensor with `Tensor::arange(0, 10, 2)` and checks that it holds
/// 0, 2, 4, 6, 8 (the stop value 10 is not included).
#[test]
fn arange() -> Result<(), ZyxError> {
    let evens = Tensor::arange(0, 10, 2)?;

    assert_eq!(evens, [0, 2, 4, 6, 8]);
    Ok(())
}
/// Adds `Tensor::constant(1)` to a 2x3 `f32` tensor, prints the sum, then
/// prints its natural logarithm.
///
/// There are no assertions; the test only checks that the operations and the
/// `Display` formatting run without panicking.
#[test]
fn const_() -> Result<(), ZyxError> {
    let x = Tensor::from([[3f32, 4., 2.], [4., 3., 2.]]);

    // NOTE(review): the trailing `'` in both format strings looks like a
    // typo; it is kept so the printed output is unchanged - confirm.
    let y = {
        let y = Tensor::constant(1) + x;
        println!("{y}'");
        y.ln()
    };
    println!("{y}'");
    Ok(())
}
/// Expands a tensor built with `Tensor::constant(2)` to shape `[1, 1]` and
/// prints it.
///
/// There are no assertions; the test passes as long as the expand succeeds
/// and printing does not panic.
#[test]
fn graph_shapes() -> Result<(), ZyxError> {
    let scalar = Tensor::constant(2);
    let y = scalar.expand([1, 1])?;

    println!("{y}");
    Ok(())
}
/// Multiplies two 5x5 tensors created with `Tensor::uniform` and prints the
/// product.
///
/// The inputs are random, so nothing is asserted about the values; the test
/// only checks that the operations complete without error.
#[test]
fn uni_matmul() -> Result<(), ZyxError> {
    let lhs = Tensor::uniform([5, 5], -1f32..2f32)?;
    let rhs = Tensor::uniform([5, 5], -1f32..5f32)?;

    let z = lhs.dot(rhs)?;

    println!("{z}");
    Ok(())
}
/// Concatenates two 2x2 tensors with `Tensor::cat`, once along axis 0 and
/// once along axis 1.
#[test]
fn cat() -> Result<(), ZyxError> {
    let a = Tensor::from([[1, 2], [3, 4]]);
    let b = Tensor::from([[5, 6], [7, 8]]);

    // Axis 0: the rows of `b` follow the rows of `a`.
    let stacked_rows = Tensor::cat([&a, &b], 0)?;
    assert_eq!(stacked_rows, [[1, 2], [3, 4], [5, 6], [7, 8]]);

    // Axis 1: each row of `b` is appended to the matching row of `a`.
    let stacked_cols = Tensor::cat([&a, &b], 1)?;
    assert_eq!(stacked_cols, [[1, 2, 5, 6], [3, 4, 7, 8]]);
    Ok(())
}
/// Loads three tensors from `xyz.safetensors`, multiplies the first two with
/// `matmul`, and compares the product element-wise against the third within
/// an absolute tolerance of 0.01.
///
/// The file path is relative, so it is resolved against the working
/// directory the test runs in.
///
/// # Panics
///
/// Panics if the file holds fewer than three tensors, if the computed product
/// and the reference have different element counts, or if any pair of
/// elements differs by 0.01 or more.
#[test]
fn matmul_1024() -> Result<(), ZyxError> {
    let mut xyz: Vec<Tensor> = Tensor::load("xyz.safetensors")?;

    // `pop` takes from the back, so the tensors come out in z, y, x order.
    // NOTE(review): assumes the file stores them as x, y, z - confirm.
    let z = xyz
        .pop()
        .expect("xyz.safetensors should contain the reference product tensor");
    let y = xyz
        .pop()
        .expect("xyz.safetensors should contain the right-hand operand tensor");
    let x = xyz
        .pop()
        .expect("xyz.safetensors should contain the left-hand operand tensor");
    println!("{:?}", x.shape());
    println!("{:?}", y.shape());

    let dataz: Vec<f32> = z.try_into()?;
    let zz = x.matmul(y)?;
    let datazz: Vec<f32> = zz.try_into()?;

    // `zip` stops at the shorter side, so without this check a truncated or
    // empty product would pass the loop below unnoticed.
    assert_eq!(
        dataz.len(),
        datazz.len(),
        "reference and computed products have different element counts"
    );

    for (index, (expected, actual)) in dataz.iter().zip(datazz).enumerate() {
        assert!(
            (expected - actual).abs() < 0.01,
            "element {index}: expected {expected}, got {actual}"
        );
    }
    Ok(())
}