yscv-tensor 0.1.8

SIMD-accelerated tensor library with f32/f16/bf16 support and 80+ operations
Documentation
use crate::{Tensor, TensorError};

#[test]
fn creates_scalar_tensor() {
    let scalar = Tensor::scalar(3.25);
    assert_eq!(scalar.shape(), &[]);
    assert_eq!(scalar.rank(), 0);
    assert_eq!(scalar.data(), &[3.25]);
}

#[test]
fn from_vec_accepts_valid_shape_and_computes_strides() {
    let tensor = Tensor::from_vec(vec![2, 3, 4], vec![0.0; 24]).unwrap();
    assert_eq!(tensor.shape(), &[2, 3, 4]);
    assert_eq!(tensor.strides(), &[12, 4, 1]);
    assert_eq!(tensor.len(), 24);
}

#[test]
fn from_vec_rejects_size_mismatch() {
    let err = Tensor::from_vec(vec![2, 2], vec![1.0, 2.0, 3.0]).unwrap_err();
    assert_eq!(
        err,
        TensorError::SizeMismatch {
            shape: vec![2, 2],
            data_len: 3
        }
    );
}

#[test]
fn get_and_set_use_multidimensional_indexing() {
    let mut tensor = Tensor::zeros(vec![2, 2]).unwrap();
    tensor.set(&[1, 0], 4.5).unwrap();
    assert_eq!(tensor.get(&[1, 0]).unwrap(), 4.5);
    assert_eq!(tensor.get(&[0, 0]).unwrap(), 0.0);
}

#[test]
fn get_rejects_invalid_index_rank() {
    let tensor = Tensor::zeros(vec![2, 2]).unwrap();
    let err = tensor.get(&[1]).unwrap_err();
    assert_eq!(
        err,
        TensorError::InvalidIndexRank {
            expected: 2,
            got: 1
        }
    );
}

#[test]
fn set_rejects_out_of_bounds_index() {
    let mut tensor = Tensor::zeros(vec![2, 2]).unwrap();
    let err = tensor.set(&[2, 0], 1.0).unwrap_err();
    assert_eq!(
        err,
        TensorError::IndexOutOfBounds {
            axis: 0,
            index: 2,
            dim: 2
        }
    );
}

#[test]
fn reshape_changes_metadata_not_data_order() {
    let tensor = Tensor::from_vec(vec![2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
    let reshaped = tensor.reshape(vec![3, 2]).unwrap();

    assert_eq!(reshaped.shape(), &[3, 2]);
    assert_eq!(reshaped.strides(), &[2, 1]);
    assert_eq!(reshaped.data(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
}

#[test]
fn reshape_rejects_size_mismatch() {
    let tensor = Tensor::zeros(vec![2, 3]).unwrap();
    let err = tensor.reshape(vec![5]).unwrap_err();
    assert_eq!(
        err,
        TensorError::ReshapeSizeMismatch {
            from: vec![2, 3],
            to: vec![5]
        }
    );
}

#[test]
fn rejects_shape_that_overflows_total_size() {
    let err = Tensor::zeros(vec![usize::MAX, 2]).unwrap_err();
    assert_eq!(
        err,
        TensorError::SizeOverflow {
            shape: vec![usize::MAX, 2]
        }
    );
}