// burn_tensor/tensor/api/bool.rs

1use crate::{backend::Backend, Bool, Int, Shape, Tensor, TensorData, TensorPrimitive};
2use alloc::vec::Vec;
3
4use crate::try_read_sync;
5
/// The region of a matrix selected when building a triangular or diagonal mask.
///
/// `tri_mask` matches on this value to pick the element-wise comparison that turns the matrix of
/// index differences into a boolean mask.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TriPart {
    /// Upper triangular part (used by `triu_mask`).
    Upper,

    /// Lower triangular part (used by `tril_mask`).
    Lower,

    /// Diagonal part (used by `diag_mask`).
    Diagonal,
}
17
18impl<B, const D: usize> Tensor<B, D, Bool>
19where
20    B: Backend,
21{
22    /// Create a boolean tensor from data on the given device.
23    pub fn from_bool(data: TensorData, device: &B::Device) -> Self {
24        Self::new(B::bool_from_data(data, device))
25    }
26
27    /// Convert the bool tensor into an int tensor.
28    pub fn int(self) -> Tensor<B, D, Int> {
29        Tensor::new(B::bool_into_int(self.primitive))
30    }
31
32    /// Convert the bool tensor into an float tensor.
33    pub fn float(self) -> Tensor<B, D> {
34        Tensor::new(TensorPrimitive::Float(B::bool_into_float(self.primitive)))
35    }
36
37    /// Inverses boolean values.
38    pub fn bool_not(self) -> Self {
39        Tensor::new(B::bool_not(self.primitive))
40    }
41
42    /// Compute the indices of the elements that are non-zero.
43    ///
44    /// # Returns
45    ///
46    /// A vector of tensors, one for each dimension of the given tensor, containing the indices of
47    /// the non-zero elements in that dimension.
48    pub fn nonzero(self) -> Vec<Tensor<B, 1, Int>> {
49        try_read_sync(self.nonzero_async())
50            .expect("Failed to read tensor data synchronously. Try using nonzero_async instead.")
51    }
52
53    /// Compute the indices of the elements that are non-zero.
54    ///
55    /// # Returns
56    ///
57    /// A vector of tensors, one for each dimension of the given tensor, containing the indices of
58    /// the non-zero elements in that dimension.
59    pub async fn nonzero_async(self) -> Vec<Tensor<B, 1, Int>> {
60        B::bool_nonzero(self.primitive)
61            .await
62            .into_iter()
63            .map(Tensor::new)
64            .collect()
65    }
66
67    /// Compute the indices of the elements that are true, grouped by element.
68    ///
69    /// # Returns
70    ///
71    /// A tensor containing the indices of all non-zero elements of the given tensor. Each row in the
72    /// result contains the indices of a non-zero element.
73    pub fn argwhere(self) -> Tensor<B, 2, Int> {
74        try_read_sync(self.argwhere_async())
75            .expect("Failed to read tensor data synchronously. Try using argwhere_async instead.")
76    }
77
78    /// Compute the indices of the elements that are true, grouped by element.
79    ///
80    /// # Returns
81    ///
82    /// A tensor containing the indices of all non-zero elements of the given tensor. Each row in the
83    /// result contains the indices of a non-zero element.
84    pub async fn argwhere_async(self) -> Tensor<B, 2, Int> {
85        Tensor::new(B::bool_argwhere(self.primitive).await)
86    }
87
88    /// Creates a mask for the upper, lower triangle, or diagonal of a matrix, which can be used to
89    /// fill the specified area with a value.
90    fn tri_mask<S: Into<Shape>>(
91        shape: S,
92        tri_part: TriPart,
93        offset: i64,
94        device: &B::Device,
95    ) -> Self {
96        let shape: Shape = shape.into();
97        let height = shape.dims[D - 2];
98        let width = shape.dims[D - 1];
99
100        // Generate row and column index tensors.
101        let row_indices: Tensor<B, 1, Int> = Tensor::arange(0..height as i64, device);
102        let col_indices: Tensor<B, 1, Int> = Tensor::arange(0..width as i64, device);
103
104        // Prepare shapes for broadcasting.
105        let mut row_shape = [1; D];
106        row_shape[D - 2] = height;
107        let mut col_shape = [1; D];
108        col_shape[D - 1] = width;
109
110        // Reshape for broadcasting.
111        let row_broadcast: Tensor<B, D, Int> = row_indices.reshape(Shape::new(row_shape));
112        let col_broadcast = col_indices.reshape(Shape::new(col_shape));
113
114        // Broadcasting trick to create a matrix that facilitates comparison for mask generation.
115        let matrix = row_broadcast.clone() - (col_broadcast.clone() - offset);
116
117        // Select the appropriate comparison function based on `tri_part`.
118        let compare = match tri_part {
119            TriPart::Upper => Tensor::greater_elem,
120            TriPart::Lower => Tensor::lower_elem,
121            TriPart::Diagonal => Tensor::not_equal_elem,
122        };
123
124        // Generate and return the mask by applying the comparison to the matrix.
125        compare(matrix, 0).unsqueeze()
126    }
127
128    /// Creates a mask for the upper triangle of a matrix, which can be used to fill the specified
129    /// area with a value.
130    ///
131    /// This function generates a boolean tensor representing the mask of the upper triangle of a matrix.
132    ///
133    /// # Arguments
134    ///
135    /// * `shape`: The shape of the matrix.
136    /// * `offset`: The offset from the diagonal, where 0 means the diagonal, and positive values shift
137    ///    towards the upper triangle.
138    /// * `device`: The device on which the tensor will be allocated.
139    ///
140    /// # Returns
141    ///
142    /// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the
143    /// upper triangle taking into account the specified `offset`. All other elements are `true`.
144    ///
145    /// # Example
146    /// ```rust
147    /// use burn_tensor::backend::Backend;
148    /// use burn_tensor::{Tensor, Bool};
149    ///
150    /// fn example<B: Backend>() {
151    ///   let mask = Tensor::<B, 2, Bool>::triu_mask([3, 3], 0, &Default::default());
152    ///   println!("{mask}");
153    ///   // [[false, false, false],
154    ///   //  [true, false, false],
155    ///   //  [true, true, false]]
156    /// }
157    /// ```
158    pub fn triu_mask<S: Into<Shape>>(shape: S, offset: i64, device: &B::Device) -> Self {
159        Self::tri_mask(shape, TriPart::Upper, offset, device)
160    }
161
162    /// Creates a mask for the lower triangle of a matrix, which can be used to fill the specified
163    /// area with a value.
164    ///
165    /// This function generates a boolean tensor representing the mask of the lower triangle of a matrix.
166    ///
167    /// # Arguments
168    ///
169    /// * `shape`: The shape of the matrix.
170    /// * `offset`: The offset from the diagonal, where 0 means the diagonal, and negative values shift
171    ///    towards the lower triangle.
172    /// * `device`: The device on which the tensor will be allocated.
173    ///
174    /// # Returns
175    ///
176    /// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the
177    /// lower triangle taking into account the specified `offset`. All other elements are `true`.
178    ///
179    /// # Example
180    /// ```rust
181    /// use burn_tensor::backend::Backend;
182    /// use burn_tensor::{Tensor, Bool};
183    ///
184    /// fn example<B: Backend>() {
185    ///   let mask = Tensor::<B, 2, Bool>::tril_mask([3, 3], 0, &Default::default());
186    ///   println!("{mask}");
187    ///   // [[false, true, true],
188    ///   //  [false, false, true],
189    ///   //  [false, false, false]]
190    /// }
191    /// ```
192    pub fn tril_mask<S: Into<Shape>>(shape: S, offset: i64, device: &B::Device) -> Self {
193        Self::tri_mask(shape, TriPart::Lower, offset, device)
194    }
195
196    /// Creates a mask for the diagonal of a matrix, which can be used to fill the specified
197    /// area with a value.
198    ///
199    /// This function generates a boolean tensor representing the mask of the diagonal of a matrix.
200    ///
201    /// # Arguments
202    ///
203    /// * `shape`: The shape of the matrix.
204    /// * `device`: The device on which the tensor will be allocated.
205    ///
206    /// # Returns
207    ///
208    /// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the
209    /// diagonal. All other elements are `true`.
210    ///
211    /// # Example
212    /// ```rust
213    /// use burn_tensor::backend::Backend;
214    /// use burn_tensor::{Tensor, Bool};
215    ///
216    /// fn example<B: Backend>() {
217    ///   let mask = Tensor::<B, 2, Bool>::diag_mask([3, 3], 0, &Default::default());
218    ///   println!("{mask}");
219    ///   // [[false, true, true],
220    ///   //  [true, false, true],
221    ///   //  [true, true, false]]
222    /// }
223    /// ```
224    pub fn diag_mask<S: Into<Shape>>(shape: S, offset: i64, device: &B::Device) -> Self {
225        Self::tri_mask(shape, TriPart::Diagonal, offset, device)
226    }
227}