// burn-tensor 0.16.1
//
// Tensor library with user-friendly APIs and automatic differentiation support.
// See the burn-tensor crate documentation.
#[burn_tensor_testgen::testgen(tri_mask)]
mod tests {
    use super::*;
    use burn_tensor::{Bool, Tensor, TensorData};

    #[test]
    fn square_diag() {
        let device = Default::default();
        let data_expected = TensorData::from([
            [false, true, true],
            [true, false, true],
            [true, true, false],
        ]);
        let tensor = TestTensorBool::<2>::diag_mask([3, 3], 0, &device);
        tensor.into_data().assert_eq(&data_expected, true);
    }

    #[test]
    fn square_diag_offset() {
        let device = Default::default();
        let data_expected =
            TensorData::from([[true, false, true], [true, true, false], [true, true, true]]);
        let tensor = TestTensorBool::<2>::diag_mask([3, 3], 1, &device);
        tensor.into_data().assert_eq(&data_expected, true);
    }

    #[test]
    fn square_tri_upper() {
        let device = Default::default();
        let data_expected = TensorData::from([
            [false, false, false],
            [true, false, false],
            [true, true, false],
        ]);
        let tensor = TestTensorBool::<2>::triu_mask([3, 3], 0, &device);
        tensor.into_data().assert_eq(&data_expected, true);
    }

    #[test]
    fn square_tri_upper_offset() {
        let device = Default::default();
        let data_expected = TensorData::from([
            [true, false, false],
            [true, true, false],
            [true, true, true],
        ]);
        let tensor = TestTensorBool::<2>::triu_mask([3, 3], 1, &device);
        tensor.into_data().assert_eq(&data_expected, true);
    }

    #[test]
    fn square_tri_lower() {
        let device = Default::default();

        let data_expected = TensorData::from([
            [false, true, true],
            [false, false, true],
            [false, false, false],
        ]);
        let tensor = TestTensorBool::<2>::tril_mask([3, 3], 0, &device);
        tensor.into_data().assert_eq(&data_expected, true);
    }

    #[test]
    fn square_tri_lower_offset() {
        let device = Default::default();

        let data_expected = TensorData::from([
            [true, true, true],
            [false, true, true],
            [false, false, true],
        ]);
        let tensor = TestTensorBool::<2>::tril_mask([3, 3], -1, &device);
        tensor.into_data().assert_eq(&data_expected, true);
    }

    #[test]
    fn rect_diag() {
        let device = Default::default();
        let data_expected = TensorData::from([
            [false, true, true, true],
            [true, false, true, true],
            [true, true, false, true],
        ]);
        let tensor = TestTensorBool::<2>::diag_mask([3, 4], 0, &device);
        tensor.into_data().assert_eq(&data_expected, true);

        let data_expected = TensorData::from([
            [false, true, true],
            [true, false, true],
            [true, true, false],
            [true, true, true],
        ]);
        let tensor = TestTensorBool::<2>::diag_mask([4, 3], 0, &device);
        tensor.into_data().assert_eq(&data_expected, true);
    }
}