burn-tensor 0.16.1

Tensor library with user-friendly APIs and automatic differentiation support
Documentation
#[burn_tensor_testgen::testgen(chunk)]
mod tests {
    use super::*;
    use alloc::vec::Vec;
    use burn_tensor::{Int, Shape, Tensor, TensorData};

    #[test]
    fn test_chunk_evenly_divisible() {
        let tensors: Vec<Tensor<TestBackend, 1, Int>> =
            Tensor::arange(0..12, &Default::default()).chunk(6, 0);
        assert_eq!(tensors.len(), 6);

        let expected = vec![
            TensorData::from([0, 1]),
            TensorData::from([2, 3]),
            TensorData::from([4, 5]),
            TensorData::from([6, 7]),
            TensorData::from([8, 9]),
            TensorData::from([10, 11]),
        ];

        for (index, tensor) in tensors.iter().enumerate() {
            tensor.to_data().assert_eq(&expected[index], false);
        }
    }

    #[test]
    fn test_chunk_not_evenly_divisible() {
        let tensors: Vec<Tensor<TestBackend, 1, Int>> =
            Tensor::arange(0..11, &Default::default()).chunk(6, 0);
        assert_eq!(tensors.len(), 6);

        let expected = vec![
            TensorData::from([0, 1]),
            TensorData::from([2, 3]),
            TensorData::from([4, 5]),
            TensorData::from([6, 7]),
            TensorData::from([8, 9]),
            TensorData::from([10]),
        ];

        for (index, tensor) in tensors.iter().enumerate() {
            tensor.to_data().assert_eq(&expected[index], false);
        }
    }

    #[test]
    fn test_chunk_not_evenly_divisible_remains_several() {
        let tensors: Vec<Tensor<TestBackend, 1, Int>> =
            Tensor::arange(0..100, &Default::default()).chunk(8, 0);
        assert_eq!(tensors.len(), 8);

        let expected = [13, 13, 13, 13, 13, 13, 13, 9];

        for (index, tensor) in tensors.iter().enumerate() {
            assert_eq!(tensor.shape().dims[0], expected[index]);
        }
    }

    #[test]
    fn test_chunk_not_divisible() {
        let tensors: Vec<Tensor<TestBackend, 1, Int>> =
            Tensor::arange(0..6, &Default::default()).chunk(7, 0);
        assert_eq!(tensors.len(), 6);

        let expected = vec![
            TensorData::from([0]),
            TensorData::from([1]),
            TensorData::from([2]),
            TensorData::from([3]),
            TensorData::from([4]),
            TensorData::from([5]),
        ];

        for (index, tensor) in tensors.iter().enumerate() {
            tensor.to_data().assert_eq(&expected[index], false);
        }
    }

    #[test]
    fn test_chunk_multi_dimension() {
        let tensors: Vec<Tensor<TestBackend, 2, Int>> =
            Tensor::from_data(TensorData::from([[0, 1, 2, 3]]), &Default::default()).chunk(2, 1);
        assert_eq!(tensors.len(), 2);

        let expected = vec![TensorData::from([[0, 1]]), TensorData::from([[2, 3]])];

        for (index, tensor) in tensors.iter().enumerate() {
            tensor.to_data().assert_eq(&expected[index], false);
        }
    }

    #[test]
    #[should_panic]
    fn test_invalid_dim() {
        let tensors: Vec<Tensor<TestBackend, 1, Int>> =
            Tensor::arange(0..12, &Default::default()).chunk(6, 1);
    }
}