// zyx 0.15.6
//
// Zyx machine learning library
// Documentation
// Copyright (C) 2025 zk4x
// SPDX-License-Identifier: LGPL-3.0-only

use zyx::{Tensor, ZyxError};

#[test]
fn reshape_1() -> Result<(), ZyxError> {
    // Pass a 2x4 tensor through several shapes that all hold 8 elements.
    // Reshape must keep the row-major element order, so the final 4x2 view
    // reads the original rows back two elements at a time.
    let reshaped = Tensor::from([[4, 5, 2, 1], [3, 4, 1, 4]])
        .reshape([8, 1])?
        .reshape([1, 2, 1, 4])?
        .reshape([4, 2])?;
    assert_eq!(reshaped, [[4, 5], [2, 1], [3, 4], [1, 4]]);
    Ok(())
}

#[test]
fn reshape_permute_1() -> Result<(), ZyxError> {
    // Reshape to [1, 2, 1, 4], permute the axes with [2, 3, 1, 0], then flatten
    // to 4x2. The expected values are 2^v of the *transposed* input, i.e. the
    // permute swaps the row/column order before the final reshape. exp2 runs in
    // F32 and the result is cast back to I32 for an exact comparison.
    let out = Tensor::from([[4, 5, 2, 1], [3, 4, 1, 4]])
        .reshape([8, 1])?
        .reshape([1, 2, 1, 4])?
        .permute([2, 3, 1, 0])?
        .reshape([4, 2])?
        .cast(zyx::DType::F32)
        .exp2()
        .cast(zyx::DType::I32);
    assert_eq!(out, [[16, 8], [32, 16], [4, 2], [2, 16]]);
    Ok(())
}

#[test]
fn expand_1() -> Result<(), ZyxError> {
    // Adding a [1, 1, 1, 4] tensor to a [1, 1, 4, 1] tensor broadcasts both to
    // [1, 1, 4, 4], where entry [i][j] = column[i] + row[j].
    let row = Tensor::from([[1, 2], [3, 4]]).reshape([1, 1, 1, 4])?;
    let column = Tensor::from([[5, 6], [7, 8]]).reshape([1, 1, 4, 1])?;
    let sum = row + column;
    assert_eq!(sum, [[[[6, 7, 8, 9], [7, 8, 9, 10], [8, 9, 10, 11], [9, 10, 11, 12]]]]);
    Ok(())
}

#[test]
fn permute_2() -> Result<(), ZyxError> {
    // Swapping the two axes of a 2x4 matrix is a transpose, giving 4x2.
    let matrix = Tensor::from([[4, 5, 2, 1], [3, 4, 1, 4]]);
    assert_eq!(matrix.permute([1, 0])?, [[4, 3], [5, 4], [2, 1], [1, 4]]);
    Ok(())
}

#[test]
fn pad_1() -> Result<(), ZyxError> {
    // Each pad pair is (before, after) and the first pair applies to axis 0:
    // two zero rows are appended below the data and the columns are untouched.
    let square = Tensor::from([[1, 2], [3, 4]]);
    assert_eq!(square.pad_zeros([(0, 2), (0, 0)])?, [[1, 2], [3, 4], [0, 0], [0, 0]]);
    Ok(())
}

#[test]
fn pad_2() -> Result<(), ZyxError> {
    // Zero-pad the last two axes of a [1, 1, 2, 2] tensor with two trailing
    // elements each (giving [1, 1, 4, 4]), then broadcast-add a [1, 1, 1, 4] row
    // so every output row is offset by [5, 6, 7, 8].
    let base = Tensor::from([[1i32, 2], [3, 4]]).reshape([1, 1, 2, 2])?;
    let row = Tensor::from([[5, 6], [7, 8]]).reshape([1, 1, 1, 4])?;
    let padded = base.pad_zeros([(0, 0), (0, 0), (0, 2), (0, 2)])?;
    let shifted = padded + row;
    assert_eq!(shifted, [[[[6i32, 8, 7, 8], [8, 10, 7, 8], [5, 6, 7, 8], [5, 6, 7, 8]]]]);
    Ok(())
}

#[test]
fn rope_1() -> Result<(), ZyxError> {
    // Rotary-embedding-style arithmetic: split `x` into halves along its last
    // axis, mix the halves with two frequency tensors, then pad the two results
    // back out to full width and sum them.
    let x = Tensor::from([1, 2, 3, 4, 5, 6, 7, 8]).reshape([2, 4])?;
    let sin_freq = Tensor::from([[2, 3], [3, 1]]);
    let cos_freq = Tensor::from([[2, 3], [3, 1]]);

    // A negative leading pad trims the front of the last axis: `a` keeps the
    // last two columns of each row.
    let a = x.rpad_zeros([(-2, 0)])?;
    // A negative trailing pad trims the back: `b` is the negated first two columns.
    let b = -(x.rpad_zeros([(0, -2)])?);
    let z = &a * &sin_freq - &b * &cos_freq;
    let z2 = a * sin_freq + b * cos_freq;
    // Positive padding widens each half back to 4 columns: `z` fills the front
    // columns and `z2` the back columns.
    let z3 = z.rpad_zeros([(0, 2)])? + z2.rpad_zeros([(2, 0)])?;

    // Release the user handles to `x` and `z2` before realizing, so the test also
    // covers evaluating tensors whose inputs are no longer held by the caller.
    drop(x);
    drop(z2);
    Tensor::realize([&z, &z3])?;
    assert_eq!(z, [[8, 18], [36, 14]]);
    assert_eq!(z3, [[8, 18, 4, 6], [36, 14, 6, 2]]);
    Ok(())
}