burn_tensor/tensor/api/bool.rs
1use crate::{Bool, Int, Shape, Tensor, TensorData, TensorPrimitive, backend::Backend};
2use alloc::{vec, vec::Vec};
3
4use crate::try_read_sync;
5
/// The part of a matrix to keep when creating a triangular mask.
///
/// Used by `tri_mask` to select which element-wise comparison produces the mask.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TriPart {
    /// Upper triangular part.
    Upper,

    /// Lower triangular part.
    Lower,

    /// Diagonal part.
    Diagonal,
}
17
18impl<B, const D: usize> Tensor<B, D, Bool>
19where
20 B: Backend,
21{
22 /// Create a boolean tensor from data on the given device.
23 pub fn from_bool(data: TensorData, device: &B::Device) -> Self {
24 Self::new(B::bool_from_data(data.convert::<B::BoolElem>(), device))
25 }
26
27 /// Convert the bool tensor into an int tensor.
28 pub fn int(self) -> Tensor<B, D, Int> {
29 Tensor::new(B::bool_into_int(self.primitive))
30 }
31
32 /// Convert the bool tensor into a float tensor.
33 pub fn float(self) -> Tensor<B, D> {
34 Tensor::new(TensorPrimitive::Float(B::bool_into_float(self.primitive)))
35 }
36
37 /// Inverses boolean values.
38 pub fn bool_not(self) -> Self {
39 Tensor::new(B::bool_not(self.primitive))
40 }
41
42 /// Performs logical and (`&&`) on two boolean tensors
43 pub fn bool_and(self, rhs: Tensor<B, D, Bool>) -> Tensor<B, D, Bool> {
44 Tensor::new(B::bool_and(self.primitive, rhs.primitive))
45 }
46
47 /// Performs logical or (`||`) on two boolean tensors
48 pub fn bool_or(self, rhs: Tensor<B, D, Bool>) -> Tensor<B, D, Bool> {
49 Tensor::new(B::bool_or(self.primitive, rhs.primitive))
50 }
51
52 /// Performs logical xor (`^`) on two boolean tensors
53 pub fn bool_xor(self, rhs: Tensor<B, D, Bool>) -> Tensor<B, D, Bool> {
54 Tensor::new(B::bool_xor(self.primitive, rhs.primitive))
55 }
56
57 /// Compute the indices of `true` elements in the tensor (i.e., non-zero for boolean tensors).
58 ///
59 /// # Returns
60 ///
61 /// A vector of tensors, one for each dimension of the given tensor, containing the indices of
62 /// the non-zero elements in that dimension.
63 pub fn nonzero(self) -> Vec<Tensor<B, 1, Int>> {
64 try_read_sync(self.nonzero_async())
65 .expect("Failed to read tensor data synchronously. Try using nonzero_async instead.")
66 }
67
68 /// Compute the indices of `true` elements in the tensor (i.e., non-zero for boolean tensors).
69 ///
70 /// # Returns
71 ///
72 /// A vector of tensors, one for each dimension of the given tensor, containing the indices of
73 /// the non-zero elements in that dimension.
74 pub async fn nonzero_async(self) -> Vec<Tensor<B, 1, Int>> {
75 let indices = self.argwhere_async().await;
76
77 if indices.shape().num_elements() == 0 {
78 // Return empty vec when all elements are zero
79 return vec![];
80 }
81
82 let dims = indices.shape().dims;
83 indices
84 .chunk(dims[1], 1)
85 .into_iter()
86 .map(|t| t.reshape(Shape::new([dims[0]])))
87 .collect()
88 }
89
90 /// Compute the indices of the elements that are true, grouped by element.
91 ///
92 /// # Returns
93 ///
94 /// A tensor containing the indices of all non-zero elements of the given tensor. Each row in the
95 /// result contains the indices of a non-zero element.
96 pub fn argwhere(self) -> Tensor<B, 2, Int> {
97 try_read_sync(self.argwhere_async())
98 .expect("Failed to read tensor data synchronously. Try using argwhere_async instead.")
99 }
100
101 /// Compute the indices of the elements that are true, grouped by element.
102 ///
103 /// # Returns
104 ///
105 /// A tensor containing the indices of all non-zero elements of the given tensor. Each row in the
106 /// result contains the indices of a non-zero element.
107 pub async fn argwhere_async(self) -> Tensor<B, 2, Int> {
108 Tensor::new(B::bool_argwhere(self.primitive).await)
109 }
110
111 /// Creates a mask for the upper, lower triangle, or diagonal of a matrix, which can be used to
112 /// fill the specified area with a value.
113 fn tri_mask<S: Into<Shape>>(
114 shape: S,
115 tri_part: TriPart,
116 offset: i64,
117 device: &B::Device,
118 ) -> Self {
119 let shape: Shape = shape.into();
120 let height = shape[D - 2];
121 let width = shape[D - 1];
122
123 // Generate row and column index tensors.
124 let row_indices: Tensor<B, 1, Int> = Tensor::arange(0..height as i64, device);
125 let col_indices: Tensor<B, 1, Int> = Tensor::arange(0..width as i64, device);
126
127 // Prepare shapes for broadcasting.
128 let mut row_shape = [1; D];
129 row_shape[D - 2] = height;
130 let mut col_shape = [1; D];
131 col_shape[D - 1] = width;
132
133 // Reshape for broadcasting.
134 let row_broadcast: Tensor<B, D, Int> = row_indices.reshape(Shape::new(row_shape));
135 let col_broadcast = col_indices.reshape(Shape::new(col_shape));
136
137 // Broadcasting trick to create a matrix that facilitates comparison for mask generation.
138 let matrix = row_broadcast.clone() - (col_broadcast.clone() - offset);
139
140 // Select the appropriate comparison function based on `tri_part`.
141 let compare = match tri_part {
142 TriPart::Upper => Tensor::greater_elem,
143 TriPart::Lower => Tensor::lower_elem,
144 TriPart::Diagonal => Tensor::not_equal_elem,
145 };
146
147 // Generate and return the mask by applying the comparison to the matrix.
148 compare(matrix, 0).unsqueeze()
149 }
150
151 /// Creates a mask for the upper triangle of a matrix, which can be used to fill the specified
152 /// area with a value.
153 ///
154 /// This function generates a boolean tensor representing the mask of the upper triangle of a matrix.
155 ///
156 /// # Arguments
157 ///
158 /// * `shape`: The shape of the matrix.
159 /// * `offset`: The offset from the diagonal, where 0 means the diagonal, and positive values shift
160 /// towards the upper triangle.
161 /// * `device`: The device on which the tensor will be allocated.
162 ///
163 /// # Returns
164 ///
165 /// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the
166 /// upper triangle taking into account the specified `offset`. All other elements are `true`.
167 ///
168 /// # Example
169 /// ```rust
170 /// use burn_tensor::backend::Backend;
171 /// use burn_tensor::{Tensor, Bool};
172 ///
173 /// fn example<B: Backend>() {
174 /// let mask = Tensor::<B, 2, Bool>::triu_mask([3, 3], 0, &Default::default());
175 /// println!("{mask}");
176 /// // [[false, false, false],
177 /// // [true, false, false],
178 /// // [true, true, false]]
179 /// }
180 /// ```
181 pub fn triu_mask<S: Into<Shape>>(shape: S, offset: i64, device: &B::Device) -> Self {
182 Self::tri_mask(shape, TriPart::Upper, offset, device)
183 }
184
185 /// Creates a mask for the lower triangle of a matrix, which can be used to fill the specified
186 /// area with a value.
187 ///
188 /// This function generates a boolean tensor representing the mask of the lower triangle of a matrix.
189 ///
190 /// # Arguments
191 ///
192 /// * `shape`: The shape of the matrix.
193 /// * `offset`: The offset from the diagonal, where 0 means the diagonal, and negative values shift
194 /// towards the lower triangle.
195 /// * `device`: The device on which the tensor will be allocated.
196 ///
197 /// # Returns
198 ///
199 /// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the
200 /// lower triangle taking into account the specified `offset`. All other elements are `true`.
201 ///
202 /// # Example
203 /// ```rust
204 /// use burn_tensor::backend::Backend;
205 /// use burn_tensor::{Tensor, Bool};
206 ///
207 /// fn example<B: Backend>() {
208 /// let mask = Tensor::<B, 2, Bool>::tril_mask([3, 3], 0, &Default::default());
209 /// println!("{mask}");
210 /// // [[false, true, true],
211 /// // [false, false, true],
212 /// // [false, false, false]]
213 /// }
214 /// ```
215 pub fn tril_mask<S: Into<Shape>>(shape: S, offset: i64, device: &B::Device) -> Self {
216 Self::tri_mask(shape, TriPart::Lower, offset, device)
217 }
218
219 /// Creates a mask for the diagonal of a matrix, which can be used to fill the specified
220 /// area with a value.
221 ///
222 /// This function generates a boolean tensor representing the mask of the diagonal of a matrix.
223 ///
224 /// # Arguments
225 ///
226 /// * `shape`: The shape of the matrix.
227 /// * `offset`: The offset from the diagonal, where 0 means the diagonal, and positive values shift
228 /// towards the upper triangle.
229 /// * `device`: The device on which the tensor will be allocated.
230 ///
231 /// # Returns
232 ///
233 /// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the
234 /// diagonal. All other elements are `true`.
235 ///
236 /// # Example
237 /// ```rust
238 /// use burn_tensor::backend::Backend;
239 /// use burn_tensor::{Tensor, Bool};
240 ///
241 /// fn example<B: Backend>() {
242 /// let mask = Tensor::<B, 2, Bool>::diag_mask([3, 3], 0, &Default::default());
243 /// println!("{mask}");
244 /// // [[false, true, true],
245 /// // [true, false, true],
246 /// // [true, true, false]]
247 /// }
248 /// ```
249 pub fn diag_mask<S: Into<Shape>>(shape: S, offset: i64, device: &B::Device) -> Self {
250 Self::tri_mask(shape, TriPart::Diagonal, offset, device)
251 }
252}