burn-tensor 0.16.1

Tensor library with user-friendly APIs and automatic differentiation support
Documentation
#[burn_tensor_testgen::testgen(split)]
mod tests {
    use super::*;
    use alloc::vec::Vec;
    use burn_tensor::{Int, Shape, Tensor, TensorData};

    /// Six rows split two at a time along dim 0 are expected to give three 2x2 chunks.
    #[test]
    fn test_split_evenly_divisible() {
        let device = Default::default();
        let input = TestTensor::<2>::from_data(
            [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11]],
            &device,
        );

        let chunks = input.split(2, 0);
        assert_eq!(chunks.len(), 3);

        let expected = [
            TensorData::from([[0, 1], [2, 3]]),
            TensorData::from([[4, 5], [6, 7]]),
            TensorData::from([[8, 9], [10, 11]]),
        ];

        // The length check above ensures every chunk is paired with an expectation.
        for (chunk, want) in chunks.iter().zip(&expected) {
            chunk.to_data().assert_eq(want, false);
        }
    }

    /// Five rows split two at a time are expected to leave a shorter final chunk.
    #[test]
    fn test_split_not_evenly_divisible() {
        let device = Default::default();
        let input = TestTensor::<2>::from_data([[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]], &device);

        let chunks = input.split(2, 0);
        assert_eq!(chunks.len(), 3);

        let expected = [
            TensorData::from([[0, 1], [2, 3]]),
            TensorData::from([[4, 5], [6, 7]]),
            TensorData::from([[8, 9]]),
        ];

        for (chunk, want) in chunks.iter().zip(&expected) {
            chunk.to_data().assert_eq(want, false);
        }
    }

    /// Splitting a 2x3 tensor along dim 1 is expected to cut columns, not rows.
    #[test]
    fn test_split_along_dim1() {
        let device = Default::default();
        let input = TestTensor::<2>::from_data([[0, 1, 2], [3, 4, 5]], &device);

        let chunks = input.split(2, 1);
        assert_eq!(chunks.len(), 2);

        let expected = [
            TensorData::from([[0, 1], [3, 4]]),
            TensorData::from([[2], [5]]),
        ];

        for (chunk, want) in chunks.iter().zip(&expected) {
            chunk.to_data().assert_eq(want, false);
        }
    }

    /// A split size larger than the dimension is expected to return the input as one chunk.
    #[test]
    fn test_split_split_size_larger_than_tensor_size() {
        let device = Default::default();
        let input = TestTensor::<1>::from_data([0, 1, 2, 3, 4], &device);

        let chunks = input.split(10, 0);
        assert_eq!(chunks.len(), 1);

        // Exactly one chunk exists (asserted above), so compare it directly.
        let expected = TensorData::from([0, 1, 2, 3, 4]);
        chunks[0].to_data().assert_eq(&expected, false);
    }

    /// A zero split size on an empty tensor is expected to yield no chunks (and not panic).
    #[test]
    fn test_split_with_zero_split_size_zero_tensor_size() {
        let device = Default::default();
        let no_values: [i32; 0] = [];
        let input = TestTensor::<1>::from_data(no_values, &device);

        let chunks = input.split(0, 0);
        assert_eq!(chunks.len(), 0);
    }

    /// An empty tensor with a non-zero split size is expected to yield no chunks.
    #[test]
    fn test_split_zero_sized_tensor() {
        let device = Default::default();
        let no_values: [i32; 0] = [];
        let input = TestTensor::<1>::from_data(no_values, &device);

        let chunks = input.split(1, 0);
        assert_eq!(chunks.len(), 0);
    }

    /// A zero split size on a non-empty tensor must panic with the message below.
    #[test]
    #[should_panic(
        expected = "split_size must be greater than 0 unless the tensor size along the dimension is 0."
    )]
    fn test_split_with_zero_split_size_non_zero_tensor() {
        let device = Default::default();
        let input = TestTensor::<1>::from_data([0, 1, 2, 3, 4], &device);

        let _chunks = input.split(0, 0);
    }

    /// Asking a rank-1 tensor to split along dim 2 must panic with the message below.
    #[test]
    #[should_panic(expected = "Given dimension is greater than or equal to the tensor rank.")]
    fn test_split_invalid_dim() {
        let device = Default::default();
        let input = TestTensor::<1>::from_data([0, 1, 2], &device);

        let _chunks = input.split(1, 2);
    }

    /// A 4x2x2 tensor split by 2 along dim 0 is expected to give two 2x2x2 halves.
    #[test]
    fn test_split_3d_tensor_along_dim0() {
        let device = Default::default();
        let input = TestTensor::<3>::from_data(
            [
                [[0, 1], [2, 3]],
                [[4, 5], [6, 7]],
                [[8, 9], [10, 11]],
                [[12, 13], [14, 15]],
            ],
            &device,
        );

        let chunks = input.split(2, 0);
        assert_eq!(chunks.len(), 2);

        let expected = [
            TensorData::from([[[0, 1], [2, 3]], [[4, 5], [6, 7]]]),
            TensorData::from([[[8, 9], [10, 11]], [[12, 13], [14, 15]]]),
        ];

        for (chunk, want) in chunks.iter().zip(&expected) {
            chunk.to_data().assert_eq(want, false);
        }
    }

    /// A 2x3x2 tensor split by 2 along dim 1 is expected to give a 2x2x2 and a 2x1x2 chunk.
    #[test]
    fn test_split_3d_tensor_along_dim1() {
        let device = Default::default();
        let input = TestTensor::<3>::from_data(
            [[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]]],
            &device,
        );

        let chunks = input.split(2, 1);
        assert_eq!(chunks.len(), 2);

        let expected = [
            TensorData::from([[[0, 1], [2, 3]], [[6, 7], [8, 9]]]),
            TensorData::from([[[4, 5]], [[10, 11]]]),
        ];

        for (chunk, want) in chunks.iter().zip(&expected) {
            chunk.to_data().assert_eq(want, false);
        }
    }

    /// Explicit sizes `[2, 3, 1]` are expected to produce chunks of exactly those lengths.
    #[test]
    fn test_split_with_sizes() {
        let device = Default::default();
        let input = TestTensor::<1>::from_data([0, 1, 2, 3, 4, 5], &device);

        let chunks = input.split_with_sizes(vec![2, 3, 1], 0);
        assert_eq!(chunks.len(), 3);

        let expected = [
            TensorData::from([0, 1]),
            TensorData::from([2, 3, 4]),
            TensorData::from([5]),
        ];

        for (chunk, want) in chunks.iter().zip(&expected) {
            chunk.to_data().assert_eq(want, false);
        }
    }

    /// Sizes summing to 5 on a 6-element tensor must panic with the message below.
    #[test]
    #[should_panic(
        expected = "The sum of split_sizes must equal the tensor size along the specified dimension."
    )]
    fn test_split_with_sizes_invalid_sum() {
        let device = Default::default();
        let input = TestTensor::<1>::from_data([0, 1, 2, 3, 4, 5], &device);

        let _chunks = input.split_with_sizes(vec![2, 2, 1], 0);
    }

    /// Three requested sizes, one of them zero, are expected to return only two chunks.
    #[test]
    fn test_split_with_sizes_zero_length() {
        let device = Default::default();
        let input = TestTensor::<1>::from_data([0, 1, 2], &device);

        let chunks = input.split_with_sizes(vec![0, 1, 2], 0);
        assert_eq!(chunks.len(), 2);

        let expected = [TensorData::from([0]), TensorData::from([1, 2])];

        for (chunk, want) in chunks.iter().zip(&expected) {
            chunk.to_data().assert_eq(want, false);
        }
    }
}