zyx 0.15.6

Zyx machine learning library
Documentation
// Copyright (C) 2025 zk4x
// SPDX-License-Identifier: LGPL-3.0-only

use zyx::{Tensor, ZyxError};

/// Reducing a two-element vector with `sum_all` yields 2 + 4.
#[test]
fn sum_1() -> Result<(), ZyxError> {
    let pair = Tensor::from([2, 4]);
    let total = pair.sum_all();
    assert_eq!(total, 6);
    Ok(())
}

/// Row sums (axis -1), column sums (axis -2) and the grand total of a 3x3 matrix.
#[test]
fn sum_2() -> Result<(), ZyxError> {
    let m = Tensor::from([[4, 1, 3], [5, 2, 3], [6, 5, 7]]);
    let (rows, cols, total) = (m.sum([-1])?, m.sum([-2])?, m.sum_all());
    assert_eq!(rows, [8, 10, 18]);
    assert_eq!(cols, [15, 8, 13]);
    assert_eq!(total, [36]);
    Ok(())
}

/// Reductions of a 2x3 matrix over axis 0, over axis 1, and over every element.
#[test]
fn sum_3() -> Result<(), ZyxError> {
    let t = Tensor::from([[2, 4, 3], [1, 5, 1]]);
    let per_column = t.sum([0])?;
    assert_eq!(per_column, [3, 9, 4]);
    let per_row = t.sum([1])?;
    assert_eq!(per_row, [9, 7]);
    let everything = t.sum_all();
    assert_eq!(everything, 16);
    Ok(())
}

/// Row sums after `relu`; on this all-positive input they match the plain row sums.
#[test]
fn sum_4() -> Result<(), ZyxError> {
    let input = Tensor::from([[4, 1, 3], [5, 2, 3], [6, 5, 7]]);
    let activated = input.relu();
    let row_sums = activated.sum([-1])?;
    assert_eq!(row_sums, [8, 10, 18]);
    Ok(())
}

/// Full reduction of a 2x3 matrix: 2 + 3 + 1 + 2 + 4 + 1 = 13.
#[test]
fn sum_5() -> Result<(), ZyxError> {
    let x = Tensor::from([[2, 3, 1], [2, 4, 1]]);
    let x = x.sum_all();
    // `assert_eq!` rather than `debug_assert_eq!`: the debug variant is compiled out when
    // `debug_assertions` is disabled (e.g. `cargo test --release`), making the test vacuous.
    assert_eq!(x, [13i32]);
    Ok(())
}

/// Two successive reductions over axis 0 of a 3x3 matrix equal the full sum (36).
#[test]
fn sum_6() -> Result<(), ZyxError> {
    let x = Tensor::from([[4i32, 1, 3], [5, 2, 3], [6, 5, 7]]);
    // First reduction collapses the rows into column sums, second collapses those.
    let x = x.sum([0])?;
    let x = x.sum([0])?;
    // `assert_eq!` rather than `debug_assert_eq!`: the debug variant is compiled out when
    // `debug_assertions` is disabled (e.g. `cargo test --release`), making the test vacuous.
    assert_eq!(x, [36i32]);
    Ok(())
}

/// Max along the last axis, along the second-to-last axis, and over all elements.
#[test]
fn max_1() -> Result<(), ZyxError> {
    let grid = Tensor::from([[4, 1, 3], [5, 2, 3], [6, 5, 7]]);
    let row_max = grid.max([-1])?;
    let col_max = grid.max([-2])?;
    let overall = grid.max_all();
    assert_eq!(row_max, [4, 5, 7]);
    assert_eq!(col_max, [6, 5, 7]);
    assert_eq!(overall, [7]);
    Ok(())
}

/// Row sums of an 8x10 matrix holding 1..=80; row `r` sums to 100 * r + 55.
#[test]
fn sum_large_2d() -> Result<(), ZyxError> {
    // Presumably sized to exercise the loop-splitting optimization — confirm if tuning.
    let matrix = Tensor::from([
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        [11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
        [21, 22, 23, 24, 25, 26, 27, 28, 29, 30],
        [31, 32, 33, 34, 35, 36, 37, 38, 39, 40],
        [41, 42, 43, 44, 45, 46, 47, 48, 49, 50],
        [51, 52, 53, 54, 55, 56, 57, 58, 59, 60],
        [61, 62, 63, 64, 65, 66, 67, 68, 69, 70],
        [71, 72, 73, 74, 75, 76, 77, 78, 79, 80],
    ]);
    let row_sums = matrix.sum([-1])?;
    assert_eq!(row_sums, [55, 155, 255, 355, 455, 555, 655, 755]);
    Ok(())
}

/// Max over the last axis of a 3x3x3 tensor holding 1..=27; each triple's max is its last entry.
#[test]
fn max_large_3d() -> Result<(), ZyxError> {
    let cube = Tensor::from([
        [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
        [[10, 11, 12], [13, 14, 15], [16, 17, 18]],
        [[19, 20, 21], [22, 23, 24], [25, 26, 27]],
    ]);
    let innermost_max = cube.max([-1])?;
    assert_eq!(innermost_max, [[3, 6, 9], [12, 15, 18], [21, 24, 27]]);
    Ok(())
}

/// Row sums of an [8, 64] tensor of ones; every row should sum to 64.
#[test]
fn sum_warp_reduce_1() -> Result<(), ZyxError> {
    // NOTE(review): the reduce dim here is 64, below the ">= 256 elements in reduce dim"
    // warp-reduce threshold cited in `sum_warp_reduce_large`, so this shape likely does not
    // take the warp-reduce path. Confirm the threshold, or use e.g. [2, 256] if warp reduce
    // is what this test is meant to cover.
    let x = Tensor::from(vec![1i32; 512]);
    let x = x.reshape([8, 64])?;
    let x0 = x.sum([-1])?;
    assert_eq!(x0, vec![64i32; 8]);
    Ok(())
}

/// Row sums of an [8, 256] tensor of ones; every row should sum to 256.
#[test]
fn sum_warp_reduce_large() -> Result<(), ZyxError> {
    // Reduce dim of 256 is intended to reach the warp-reduce threshold (reduce dim >= 256).
    let ones = Tensor::from(vec![1i32; 2048]);
    let matrix = ones.reshape([8, 256])?;
    let row_sums = matrix.sum([-1])?;
    assert_eq!(row_sums, vec![256i32; 8]);
    Ok(())
}

/// Row sums of an [8, 4096] tensor of ones (32768 elements); every row sums to 4096.
#[test]
fn sum_reduce_32k() -> Result<(), ZyxError> {
    // Intended to push the scheduler toward the tiled reduce for large workloads; this test
    // only checks the numeric result, not which kernel strategy was chosen.
    let ones = Tensor::from(vec![1i32; 32768]);
    let matrix = ones.reshape([8, 4096])?;
    let row_sums = matrix.sum([-1])?;
    assert_eq!(row_sums, vec![4096i32; 8]);
    Ok(())
}

/// Sum over the last dimension of a [4, 1024, 512] tensor whose flat element `i` is `i % 10`.
#[test]
fn sum_reduce_last_dim_3d() -> Result<(), ZyxError> {
    // 3D tensor with a large reduction over the last dimension.
    // Shape: [4, 1024, 512] -> reduce last dim -> [4, 1024]
    let data: Vec<i32> = (0..(4 * 1024 * 512)).map(|i| (i % 10) as i32).collect();
    let x = Tensor::from(data);
    let x = x.reshape([4, 1024, 512])?;
    let x0 = x.sum([-1])?;

    // Expected value for output [i][j]: sum of `k % 10` over the 512 flat indices of that row.
    let mut expected = [[0i32; 1024]; 4];
    for (i, plane) in expected.iter_mut().enumerate() {
        for (j, slot) in plane.iter_mut().enumerate() {
            let start = (i * 1024 + j) * 512;
            *slot = (start..start + 512).map(|k| (k % 10) as i32).sum();
        }
    }

    // Plain `assert!` instead of `assert_eq!`: on failure `assert_eq!` would print the Debug
    // form of both sides, i.e. all 4096 sums of each.
    assert!(
        x0 == expected,
        "last-dim sum of the [4, 1024, 512] tensor does not match the expected row sums"
    );
    Ok(())
}