burn-tensor 0.16.1

Tensor library with user-friendly APIs and automatic differentiation support
Documentation
#[burn_tensor_testgen::testgen(argwhere_nonzero)]
mod tests {
    use super::*;
    use alloc::vec::Vec;
    use burn_tensor::{Tensor, TensorData};

    // `argwhere` tests: the output is compared against a 2-D table holding one
    // row of coordinates per `true` element of the input, in row-major order.

    #[test]
    fn test_argwhere_1d() {
        let tensor = TestTensorBool::<1>::from([false, true, false, true, true]);
        let output = tensor.argwhere();

        output
            .into_data()
            .assert_eq(&TensorData::from([[1], [3], [4]]), false);
    }

    #[test]
    fn test_argwhere_2d() {
        let tensor = TestTensorBool::<2>::from([[false, false], [false, true], [true, true]]);
        let output = tensor.argwhere();

        output
            .into_data()
            .assert_eq(&TensorData::from([[1, 1], [2, 0], [2, 1]]), false);
    }

    #[test]
    fn test_argwhere_3d() {
        let tensor = TestTensorBool::<3>::from([
            [[false, false, false], [false, true, false]],
            [[true, false, true], [true, true, false]],
        ]);
        let output = tensor.argwhere();

        output.into_data().assert_eq(
            &TensorData::from([[0, 1, 1], [1, 0, 0], [1, 0, 2], [1, 1, 0], [1, 1, 1]]),
            false,
        );
    }

    // `nonzero` tests: the output is a collection with one 1-D index tensor per
    // input dimension; entry `k` lists the coordinate along axis `k` of every
    // `true` element.

    #[test]
    fn test_nonzero_1d() {
        let tensor = TestTensorBool::<1>::from([false, true, false, true, true]);
        let data_actual = tensor
            .nonzero()
            .into_iter()
            .map(|t| t.into_data())
            .collect::<Vec<_>>();

        // Check the length first so the indexing below cannot go out of bounds.
        assert_eq!(data_actual.len(), 1);
        data_actual[0].assert_eq(&TensorData::from([1, 3, 4]), false);
    }

    #[test]
    fn test_nonzero_2d() {
        let tensor = TestTensorBool::<2>::from([[false, false], [false, true], [true, true]]);
        let data_actual = tensor
            .nonzero()
            .into_iter()
            .map(|t| t.into_data())
            .collect::<Vec<_>>();
        // A fixed-size array is enough here: the expectations are only read.
        let data_expected = [TensorData::from([1, 2, 2]), TensorData::from([1, 0, 1])];

        // Assert the count before zipping: `zip` stops at the shorter side and
        // would otherwise hide a missing or extra output tensor.
        assert_eq!(data_actual.len(), data_expected.len());
        for (actual, expected) in data_actual.iter().zip(data_expected.iter()) {
            actual.assert_eq(expected, false);
        }
    }

    #[test]
    fn test_nonzero_3d() {
        let tensor = TestTensorBool::<3>::from([
            [[false, false, false], [false, true, false]],
            [[true, false, true], [true, true, false]],
        ]);
        let data_actual = tensor
            .nonzero()
            .into_iter()
            .map(|t| t.into_data())
            .collect::<Vec<_>>();
        // A fixed-size array is enough here: the expectations are only read.
        let data_expected = [
            TensorData::from([0, 1, 1, 1, 1]),
            TensorData::from([1, 0, 0, 1, 1]),
            TensorData::from([1, 0, 2, 0, 1]),
        ];

        // Assert the count before zipping: `zip` stops at the shorter side and
        // would otherwise hide a missing or extra output tensor.
        assert_eq!(data_actual.len(), data_expected.len());
        for (actual, expected) in data_actual.iter().zip(data_expected.iter()) {
            actual.assert_eq(expected, false);
        }
    }

    #[test]
    fn test_nonzero_empty() {
        // With no `true` element the returned collection is expected to be empty.
        let tensor = TestTensorBool::<1>::from([false, false, false, false, false]);
        let output = tensor.nonzero();

        assert_eq!(output.len(), 0);
    }
}