burn-tensor 0.16.1

Tensor library with user-friendly APIs and automatic differentiation support
Documentation
#[burn_tensor_testgen::testgen(iter_dim)]
mod test {
    use super::*;
    use burn_tensor::{backend::Backend, Int, Tensor, TensorData};

    #[test]
    fn test_1d_iter_last_item() {
        let data = [1, 2, 3, 4];
        let device = Default::default();
        let tensor = TestTensorInt::<1>::from_ints(data, &device);
        assert_eq!(
            TestTensorInt::<1>::from_ints([4], &device).into_data(),
            tensor.iter_dim(0).last().unwrap().into_data()
        )
    }

    #[test]
    #[should_panic]
    fn test_too_high_dimension() {
        TestTensor::<1>::zeros([10], &Default::default()).iter_dim(1);
    }

    #[test]
    fn test_transposed() {
        let data = [
            [1., 2., 3., 1., 2.],
            [4., 5., 6., 1., 2.],
            [7., 8., 9., 1., 2.],
        ];
        let tensor = TestTensor::<2>::from_floats(data, &Default::default());
        let lhs = tensor.clone().slice([1..2, 0..5]);
        let rhs = tensor.transpose().iter_dim(1).nth(1).unwrap();
        assert_eq!(
            lhs.into_data()
                .as_slice::<<TestBackend as Backend>::FloatElem>()
                .unwrap(),
            rhs.into_data()
                .as_slice::<<TestBackend as Backend>::FloatElem>()
                .unwrap()
        );
    }

    fn test_iteration_over_low_dim() {
        let data = [[
            [1., 2., 3., 1., 2.],
            [4., 5., 6., 1., 2.],
            [7., 8., 9., 1., 2.],
        ]; 5];
        let tensor = TestTensor::<3>::from_floats(data, &Default::default());
        let lhs = tensor.iter_dim(2).nth(1).unwrap();
        let rhs = TensorData::from([2., 5., 8.]);
        assert_eq!(
            lhs.into_data()
                .as_slice::<<TestBackend as Backend>::FloatElem>()
                .unwrap(),
            rhs.as_slice::<<TestBackend as Backend>::FloatElem>()
                .unwrap()
        );
    }

    #[test]
    fn test_iter_dim_double_end() {
        let input =
            TestTensorInt::<1>::arange(0..(4 * 6 * 3), &Default::default()).reshape([4, 6, 3]);
        let mut iter = input.iter_dim(1);

        let ele0 = TensorData::from([[[0, 1, 2]], [[18, 19, 20]], [[36, 37, 38]], [[54, 55, 56]]]);
        let ele1 = TensorData::from([[[3, 4, 5]], [[21, 22, 23]], [[39, 40, 41]], [[57, 58, 59]]]);
        let ele2 = TensorData::from([[[6, 7, 8]], [[24, 25, 26]], [[42, 43, 44]], [[60, 61, 62]]]);
        let ele3 = TensorData::from([
            [[9, 10, 11]],
            [[27, 28, 29]],
            [[45, 46, 47]],
            [[63, 64, 65]],
        ]);
        let ele4 = TensorData::from([
            [[12, 13, 14]],
            [[30, 31, 32]],
            [[48, 49, 50]],
            [[66, 67, 68]],
        ]);
        let ele5 = TensorData::from([
            [[15, 16, 17]],
            [[33, 34, 35]],
            [[51, 52, 53]],
            [[69, 70, 71]],
        ]);

        iter.next().unwrap().into_data().assert_eq(&ele0, false);
        iter.next_back()
            .unwrap()
            .into_data()
            .assert_eq(&ele5, false);
        iter.next_back()
            .unwrap()
            .into_data()
            .assert_eq(&ele4, false);
        iter.next().unwrap().into_data().assert_eq(&ele1, false);
        iter.next().unwrap().into_data().assert_eq(&ele2, false);
        iter.next().unwrap().into_data().assert_eq(&ele3, false);
        assert!(iter.next().is_none());
        assert!(iter.next_back().is_none());
    }

    #[test]
    fn test_iter_dim_single_element() {
        let input =
            TestTensorInt::<1>::arange(0..(4 * 1 * 3), &Default::default()).reshape([4, 1, 3]);

        let mut iter = input.clone().iter_dim(1);
        iter.next()
            .unwrap()
            .into_data()
            .assert_eq(&input.clone().into_data(), true);
        assert!(iter.next_back().is_none());
        assert!(iter.next().is_none());

        let mut iter = input.clone().iter_dim(1);
        iter.next_back()
            .unwrap()
            .into_data()
            .assert_eq(&input.clone().into_data(), true);
        assert!(iter.next().is_none());
        assert!(iter.next_back().is_none());
    }
}