1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
use crate::shapes::{Dtype, Shape};
use crate::tensor::{HasErr, Tape, Tensor, TriangleTensor};

use super::TryMul;

/// Applies a lower-triangular mask to the last two dimensions of `t`.
///
/// Element `[.., i, j]` is kept when `j <= i + diagonal` and zeroed otherwise. Passing `None`
/// as `diagonal` behaves like `0` (the main diagonal). A positive offset keeps that many extra
/// diagonals above the main one; a negative offset also drops diagonals at and below it.
///
/// The mask is applied by multiplying `t` with a 0/1 tensor (`E::ONE` inside the triangle).
/// For floating-point dtypes, a masked-out position that holds NaN or ±inf therefore
/// becomes NaN rather than zero.
///
/// # Panics
///
/// Panics if building the mask or performing the multiplication fails on the device; use
/// [`Tensor::try_lower_tri`] to get the error instead.
///
/// See [`TriangleTensor::lower_tri`].
pub fn lower_tri<S: Shape, E: Dtype, D: TriangleTensor<E>, T: Tape<E, D>>(
    t: Tensor<S, E, D, T>,
    diagonal: impl Into<Option<isize>>,
) -> Tensor<S, E, D, T>
where
    Tensor<S, E, D, T>: TryMul<Tensor<S, E, D>> + HasErr<Err = D::Err>,
{
    t.lower_tri(diagonal)
}

/// Applies an upper-triangular mask to the last two dimensions of `t`.
///
/// Element `[.., i, j]` is kept when `j >= i + diagonal` and zeroed otherwise. Passing `None`
/// as `diagonal` uses the main diagonal. A negative offset keeps that many extra diagonals
/// below the main one (e.g. `-1` also keeps the first sub-diagonal); a positive offset also
/// drops diagonals at and above the main one.
///
/// The mask is applied by multiplying `t` with a 0/1 tensor (`E::ONE` inside the triangle).
/// For floating-point dtypes, a masked-out position that holds NaN or ±inf therefore
/// becomes NaN rather than zero.
///
/// # Panics
///
/// Panics if building the mask or performing the multiplication fails on the device; use
/// [`Tensor::try_upper_tri`] to get the error instead.
///
/// See [`TriangleTensor::upper_tri`].
pub fn upper_tri<S: Shape, E: Dtype, D: TriangleTensor<E>, T: Tape<E, D>>(
    t: Tensor<S, E, D, T>,
    diagonal: impl Into<Option<isize>>,
) -> Tensor<S, E, D, T>
where
    Tensor<S, E, D, T>: TryMul<Tensor<S, E, D>> + HasErr<Err = D::Err>,
{
    t.upper_tri(diagonal)
}

impl<S: Shape, E: Dtype, D: TriangleTensor<E>, T: Tape<E, D>> Tensor<S, E, D, T>
where
    Self: TryMul<Tensor<S, E, D>> + HasErr<Err = D::Err>,
{
    /// Fallible form of [lower_tri]: multiplies `self` by a same-shaped mask that holds
    /// `E::ONE` inside the lower triangle selected by `diagonal`.
    ///
    /// # Errors
    ///
    /// Returns the device error if the mask cannot be built or the multiplication fails.
    pub fn try_lower_tri(
        self,
        diagonal: impl Into<Option<isize>>,
    ) -> Result<Self, <Self as HasErr>::Err> {
        // Both borrows end once the mask exists, so `self` can be consumed afterwards.
        let (device, shape) = (&self.device, &self.shape);
        let mask = device.try_lower_tri_like(shape, E::ONE, diagonal)?;
        self.try_mul(mask)
    }

    /// Fallible form of [upper_tri]: multiplies `self` by a same-shaped mask that holds
    /// `E::ONE` inside the upper triangle selected by `diagonal`.
    ///
    /// # Errors
    ///
    /// Returns the device error if the mask cannot be built or the multiplication fails.
    pub fn try_upper_tri(
        self,
        diagonal: impl Into<Option<isize>>,
    ) -> Result<Self, <Self as HasErr>::Err> {
        // Both borrows end once the mask exists, so `self` can be consumed afterwards.
        let (device, shape) = (&self.device, &self.shape);
        let mask = device.try_upper_tri_like(shape, E::ONE, diagonal)?;
        self.try_mul(mask)
    }

    /// Panicking form of [lower_tri].
    ///
    /// # Panics
    ///
    /// Panics with the device error if [`Self::try_lower_tri`] fails.
    pub fn lower_tri(self, diagonal: impl Into<Option<isize>>) -> Self {
        Self::try_lower_tri(self, diagonal).unwrap()
    }

    /// Panicking form of [upper_tri].
    ///
    /// # Panics
    ///
    /// Panics with the device error if [`Self::try_upper_tri`] fails.
    pub fn upper_tri(self, diagonal: impl Into<Option<isize>>) -> Self {
        Self::try_upper_tri(self, diagonal).unwrap()
    }
}

#[cfg(test)]
mod tests {
    use crate::{tensor::*, tests::*};

    #[test]
    fn test_tri() {
        let dev: TestDevice = Default::default();

        let t = dev
            .tensor(
                [[[
                    [1., 2., 3., 4., 5., 6.],
                    [1., 2., 3., 4., 5., 6.],
                    [1., 2., 3., 4., 5., 6.],
                    [1., 2., 3., 4., 5., 6.],
                    [1., 2., 3., 4., 5., 6.],
                ]; 4]; 3],
            )
            .to_dtype::<TestDtype>();
        assert_close_to_literal!(
            t.clone().lower_tri(None),
            [[[
                [1., 0., 0., 0., 0., 0.],
                [1., 2., 0., 0., 0., 0.],
                [1., 2., 3., 0., 0., 0.],
                [1., 2., 3., 4., 0., 0.],
                [1., 2., 3., 4., 5., 0.],
            ]; 4]; 3]
        );
        assert_close_to_literal!(
            t.clone().lower_tri(2),
            [[[
                [1., 2., 3., 0., 0., 0.],
                [1., 2., 3., 4., 0., 0.],
                [1., 2., 3., 4., 5., 0.],
                [1., 2., 3., 4., 5., 6.],
                [1., 2., 3., 4., 5., 6.],
            ]; 4]; 3]
        );
        assert_close_to_literal!(
            t.upper_tri(-1),
            [[[
                [1., 2., 3., 4., 5., 6.],
                [1., 2., 3., 4., 5., 6.],
                [0., 2., 3., 4., 5., 6.],
                [0., 0., 3., 4., 5., 6.],
                [0., 0., 0., 4., 5., 6.],
            ]; 4]; 3]
        );
    }

    /// Covers the offsets `test_tri` does not: a negative offset for the lower mask, the
    /// default and a positive offset for the upper mask, and the fallible `try_*` methods.
    #[test]
    fn test_tri_other_offsets() {
        let dev: TestDevice = Default::default();

        let t = dev
            .tensor(
                [[[
                    [1., 2., 3., 4., 5., 6.],
                    [1., 2., 3., 4., 5., 6.],
                    [1., 2., 3., 4., 5., 6.],
                    [1., 2., 3., 4., 5., 6.],
                    [1., 2., 3., 4., 5., 6.],
                ]; 4]; 3],
            )
            .to_dtype::<TestDtype>();
        // Keeps j <= i - 1, so the first row is fully masked.
        assert_close_to_literal!(
            t.clone().try_lower_tri(-1).unwrap(),
            [[[
                [0., 0., 0., 0., 0., 0.],
                [1., 0., 0., 0., 0., 0.],
                [1., 2., 0., 0., 0., 0.],
                [1., 2., 3., 0., 0., 0.],
                [1., 2., 3., 4., 0., 0.],
            ]; 4]; 3]
        );
        // Keeps j >= i.
        assert_close_to_literal!(
            t.clone().upper_tri(None),
            [[[
                [1., 2., 3., 4., 5., 6.],
                [0., 2., 3., 4., 5., 6.],
                [0., 0., 3., 4., 5., 6.],
                [0., 0., 0., 4., 5., 6.],
                [0., 0., 0., 0., 5., 6.],
            ]; 4]; 3]
        );
        // Keeps j >= i + 2, so the last row is fully masked.
        assert_close_to_literal!(
            t.try_upper_tri(2).unwrap(),
            [[[
                [0., 0., 3., 4., 5., 6.],
                [0., 0., 0., 4., 5., 6.],
                [0., 0., 0., 0., 5., 6.],
                [0., 0., 0., 0., 0., 6.],
                [0., 0., 0., 0., 0., 0.],
            ]; 4]; 3]
        );
    }
}