burn_tensor/tensor/api/bool.rs
1use crate::{Bool, Int, Shape, Tensor, TensorData, TensorPrimitive, backend::Backend};
2use alloc::{vec, vec::Vec};
3
4use crate::try_read_sync;
5
/// Region of a matrix that a triangular/diagonal mask leaves unmasked (`false`).
///
/// The "selected diagonal" is the main diagonal shifted by the mask's `offset`.
enum TriPart {
    /// Leave the upper triangle unmasked: on and above the selected diagonal.
    Upper,
    /// Leave the lower triangle unmasked: on and below the selected diagonal.
    Lower,
    /// Leave only the selected diagonal unmasked.
    Diagonal,
}
17
18impl<B, const D: usize> Tensor<B, D, Bool>
19where
20 B: Backend,
21{
22 /// Create a boolean tensor from data on the given device.
23 ///
24 /// # Arguments
25 ///
26 /// * `data` - The tensor data.
27 /// * `device` - The device on which the tensor will be allocated.
28 ///
29 /// # Returns
30 ///
31 /// A boolean tensor.
32 ///
33 /// # Example
34 ///
35 /// ```rust
36 /// use burn_tensor::backend::Backend;
37 /// use burn_tensor::{Tensor, Bool};
38 ///
39 /// fn example<B: Backend>() {
40 /// let device = Default::default();
41 /// let tensor = Tensor::<B, 2, Bool>::from_bool([[true, false], [false, true]].into(), &device);
42 /// println!("{tensor}");
43 /// }
44 /// ```
45 pub fn from_bool(data: TensorData, device: &B::Device) -> Self {
46 Self::new(B::bool_from_data(data.convert::<B::BoolElem>(), device))
47 }
48
49 /// Convert the bool tensor into an int tensor.
50 ///
51 /// # Returns
52 ///
53 /// An integer tensor where `true` is converted to `1` and `false` to `0`.
54 ///
55 /// # Example
56 ///
57 /// ```rust
58 /// use burn_tensor::backend::Backend;
59 /// use burn_tensor::{Tensor, Bool};
60 ///
61 /// fn example<B: Backend>() {
62 /// let device = Default::default();
63 /// let bool_tensor = Tensor::<B, 1, Bool>::from_bool([true, false, true].into(), &device);
64 /// let int_tensor = bool_tensor.int();
65 /// println!("{int_tensor}"); // [1, 0, 1]
66 /// }
67 /// ```
68 pub fn int(self) -> Tensor<B, D, Int> {
69 Tensor::new(B::bool_into_int(self.primitive))
70 }
71
72 /// Convert the bool tensor into a float tensor.
73 ///
74 /// # Returns
75 ///
76 /// A float tensor where `true` is converted to `1.0` and `false` to `0.0`.
77 ///
78 /// # Example
79 ///
80 /// ```rust
81 /// use burn_tensor::backend::Backend;
82 /// use burn_tensor::{Tensor, Bool};
83 ///
84 /// fn example<B: Backend>() {
85 /// let device = Default::default();
86 /// let bool_tensor = Tensor::<B, 1, Bool>::from_bool([true, false, true].into(), &device);
87 /// let float_tensor = bool_tensor.float();
88 /// println!("{float_tensor}"); // [1.0, 0.0, 1.0]
89 /// }
90 /// ```
91 pub fn float(self) -> Tensor<B, D> {
92 Tensor::new(TensorPrimitive::Float(B::bool_into_float(self.primitive)))
93 }
94
95 /// Inverses boolean values.
96 ///
97 /// # Example
98 ///
99 /// ```rust
100 /// use burn_tensor::backend::Backend;
101 /// use burn_tensor::{Tensor, Bool};
102 ///
103 /// fn example<B: Backend>() {
104 /// let device = Default::default();
105 /// let tensor = Tensor::<B, 2, Bool>::from_bool([[true, false], [false, true]].into(), &device);
106 /// let inverted = tensor.bool_not();
107 /// println!("{inverted}"); // [[false, true], [true, false]]
108 /// }
109 /// ```
110 pub fn bool_not(self) -> Self {
111 Tensor::new(B::bool_not(self.primitive))
112 }
113
114 /// Performs logical and (`&&`) on two boolean tensors.
115 ///
116 /// # Arguments
117 ///
118 /// * `rhs` - The right-hand side tensor for the AND operation.
119 ///
120 /// # Returns
121 ///
122 /// A boolean tensor where each element is the result of `self[i] && rhs[i]`.
123 ///
124 /// # Example
125 ///
126 /// ```rust
127 /// use burn_tensor::backend::Backend;
128 /// use burn_tensor::{Tensor, Bool};
129 ///
130 /// fn example<B: Backend>() {
131 /// let device = Default::default();
132 /// let a = Tensor::<B, 2, Bool>::from_bool([[true, true], [false, false]].into(), &device);
133 /// let b = Tensor::<B, 2, Bool>::from_bool([[true, false], [true, false]].into(), &device);
134 /// let result = a.bool_and(b);
135 /// println!("{result}"); // [[true, false], [false, false]]
136 /// }
137 /// ```
138 pub fn bool_and(self, rhs: Tensor<B, D, Bool>) -> Tensor<B, D, Bool> {
139 Tensor::new(B::bool_and(self.primitive, rhs.primitive))
140 }
141
142 /// Performs logical or (`||`) on two boolean tensors.
143 ///
144 /// # Arguments
145 ///
146 /// * `rhs` - The right-hand side tensor for the OR operation.
147 ///
148 /// # Returns
149 ///
150 /// A boolean tensor where each element is the result of `self[i] || rhs[i]`.
151 ///
152 /// # Example
153 ///
154 /// ```rust
155 /// use burn_tensor::backend::Backend;
156 /// use burn_tensor::{Tensor, Bool};
157 ///
158 /// fn example<B: Backend>() {
159 /// let device = Default::default();
160 /// let a = Tensor::<B, 2, Bool>::from_bool([[true, true], [false, false]].into(), &device);
161 /// let b = Tensor::<B, 2, Bool>::from_bool([[true, false], [true, false]].into(), &device);
162 /// let result = a.bool_or(b);
163 /// println!("{result}"); // [[true, true], [true, false]]
164 /// }
165 /// ```
166 pub fn bool_or(self, rhs: Tensor<B, D, Bool>) -> Tensor<B, D, Bool> {
167 Tensor::new(B::bool_or(self.primitive, rhs.primitive))
168 }
169
170 /// Performs logical xor (`^`) on two boolean tensors.
171 ///
172 /// # Arguments
173 ///
174 /// * `rhs` - The right-hand side tensor for the XOR operation.
175 ///
176 /// # Returns
177 ///
178 /// A boolean tensor where each element is the result of `self[i] ^ rhs[i]`.
179 /// Returns `true` when exactly one of the operands is `true`.
180 ///
181 /// # Example
182 ///
183 /// ```rust
184 /// use burn_tensor::backend::Backend;
185 /// use burn_tensor::{Tensor, Bool};
186 ///
187 /// fn example<B: Backend>() {
188 /// let device = Default::default();
189 /// let a = Tensor::<B, 2, Bool>::from_bool([[true, true], [false, false]].into(), &device);
190 /// let b = Tensor::<B, 2, Bool>::from_bool([[true, false], [true, false]].into(), &device);
191 /// let result = a.bool_xor(b);
192 /// println!("{result}"); // [[false, true], [true, false]]
193 /// }
194 /// ```
195 pub fn bool_xor(self, rhs: Tensor<B, D, Bool>) -> Tensor<B, D, Bool> {
196 Tensor::new(B::bool_xor(self.primitive, rhs.primitive))
197 }
198
199 /// Compute the indices of `true` elements in the tensor (i.e., non-zero for boolean tensors).
200 ///
201 /// # Returns
202 ///
203 /// A vector of tensors, one for each dimension of the given tensor, containing the indices of
204 /// the non-zero elements in that dimension.
205 ///
206 /// # Example
207 ///
208 /// ```rust
209 /// use burn_tensor::backend::Backend;
210 /// use burn_tensor::{Tensor, Bool};
211 ///
212 /// fn example<B: Backend>() {
213 /// let device = Default::default();
214 /// let tensor = Tensor::<B, 2, Bool>::from_bool(
215 /// [[true, false, true], [false, true, false], [false, true, false]].into(),
216 /// &device,
217 /// );
218 /// let indices = tensor.nonzero();
219 /// println!("{}", indices[0]); // [0, 0, 1, 2]
220 /// println!("{}", indices[1]); // [0, 2, 1, 1]
221 /// }
222 /// ```
223 pub fn nonzero(self) -> Vec<Tensor<B, 1, Int>> {
224 try_read_sync(self.nonzero_async())
225 .expect("Failed to read tensor data synchronously. Try using nonzero_async instead.")
226 }
227
228 /// Compute the indices of `true` elements in the tensor (i.e., non-zero for boolean tensors).
229 ///
230 /// # Returns
231 ///
232 /// A vector of tensors, one for each dimension of the given tensor, containing the indices of
233 /// the non-zero elements in that dimension.
234 pub async fn nonzero_async(self) -> Vec<Tensor<B, 1, Int>> {
235 let indices = self.argwhere_async().await;
236
237 if indices.shape().num_elements() == 0 {
238 // Return empty vec when all elements are zero
239 return vec![];
240 }
241
242 let dims = indices.shape().dims;
243 indices
244 .chunk(dims[1], 1)
245 .into_iter()
246 .map(|t| t.reshape(Shape::new([dims[0]])))
247 .collect()
248 }
249
250 /// Compute the indices of the elements that are true, grouped by element.
251 ///
252 /// # Returns
253 ///
254 /// A tensor containing the indices of all non-zero elements of the given tensor. Each row in the
255 /// result contains the indices of a non-zero element.
256 ///
257 /// # Example
258 ///
259 /// ```rust
260 /// use burn_tensor::backend::Backend;
261 /// use burn_tensor::{Tensor, Bool};
262 ///
263 /// fn example<B: Backend>() {
264 /// let device = Default::default();
265 /// let tensor = Tensor::<B, 2, Bool>::from_bool(
266 /// [[true, false, true], [false, true, false], [false, true, false]].into(),
267 /// &device,
268 /// );
269 /// let indices = tensor.argwhere();
270 /// println!("{indices}"); // [[0, 0], [0, 2], [1, 1], [2, 1]]
271 /// }
272 /// ```
273 pub fn argwhere(self) -> Tensor<B, 2, Int> {
274 try_read_sync(self.argwhere_async())
275 .expect("Failed to read tensor data synchronously. Try using argwhere_async instead.")
276 }
277
278 /// Compute the indices of the elements that are true, grouped by element.
279 ///
280 /// # Returns
281 ///
282 /// A tensor containing the indices of all non-zero elements of the given tensor. Each row in the
283 /// result contains the indices of a non-zero element.
284 pub async fn argwhere_async(self) -> Tensor<B, 2, Int> {
285 Tensor::new(B::bool_argwhere(self.primitive).await)
286 }
287
288 /// Creates a mask for the upper, lower triangle, or diagonal of a matrix, which can be used to
289 /// fill the specified area with a value.
290 fn tri_mask<S: Into<Shape>>(
291 shape: S,
292 tri_part: TriPart,
293 offset: i64,
294 device: &B::Device,
295 ) -> Self {
296 let shape: Shape = shape.into();
297 let height = shape[D - 2];
298 let width = shape[D - 1];
299
300 // Generate row and column index tensors.
301 let row_indices: Tensor<B, 1, Int> = Tensor::arange(0..height as i64, device);
302 let col_indices: Tensor<B, 1, Int> = Tensor::arange(0..width as i64, device);
303
304 // Prepare shapes for broadcasting.
305 let mut row_shape = [1; D];
306 row_shape[D - 2] = height;
307 let mut col_shape = [1; D];
308 col_shape[D - 1] = width;
309
310 // Reshape for broadcasting.
311 let row_broadcast: Tensor<B, D, Int> = row_indices.reshape(Shape::new(row_shape));
312 let col_broadcast = col_indices.reshape(Shape::new(col_shape));
313
314 // Broadcasting trick to create a matrix that facilitates comparison for mask generation.
315 let matrix = row_broadcast.clone() - (col_broadcast.clone() - offset);
316
317 // Select the appropriate comparison function based on `tri_part`.
318 let compare = match tri_part {
319 TriPart::Upper => Tensor::greater_elem,
320 TriPart::Lower => Tensor::lower_elem,
321 TriPart::Diagonal => Tensor::not_equal_elem,
322 };
323
324 // Generate and return the mask by applying the comparison to the matrix.
325 compare(matrix, 0).unsqueeze()
326 }
327
328 /// Creates a mask for the upper triangle of a matrix, which can be used to fill the specified
329 /// area with a value.
330 ///
331 /// This function generates a boolean tensor representing the mask of the upper triangle of a matrix.
332 ///
333 /// # Arguments
334 ///
335 /// * `shape`: The shape of the matrix.
336 /// * `offset`: The offset from the diagonal, where 0 means the diagonal, and positive values shift
337 /// towards the upper triangle.
338 /// * `device`: The device on which the tensor will be allocated.
339 ///
340 /// # Returns
341 ///
342 /// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the
343 /// upper triangle taking into account the specified `offset`. All other elements are `true`.
344 ///
345 /// # Example
346 /// ```rust
347 /// use burn_tensor::backend::Backend;
348 /// use burn_tensor::{Tensor, Bool};
349 ///
350 /// fn example<B: Backend>() {
351 /// let mask = Tensor::<B, 2, Bool>::triu_mask([3, 3], 0, &Default::default());
352 /// println!("{mask}");
353 /// // [[false, false, false],
354 /// // [true, false, false],
355 /// // [true, true, false]]
356 /// }
357 /// ```
358 pub fn triu_mask<S: Into<Shape>>(shape: S, offset: i64, device: &B::Device) -> Self {
359 Self::tri_mask(shape, TriPart::Upper, offset, device)
360 }
361
362 /// Creates a mask for the lower triangle of a matrix, which can be used to fill the specified
363 /// area with a value.
364 ///
365 /// This function generates a boolean tensor representing the mask of the lower triangle of a matrix.
366 ///
367 /// # Arguments
368 ///
369 /// * `shape`: The shape of the matrix.
370 /// * `offset`: The offset from the diagonal, where 0 means the diagonal, and negative values shift
371 /// towards the lower triangle.
372 /// * `device`: The device on which the tensor will be allocated.
373 ///
374 /// # Returns
375 ///
376 /// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the
377 /// lower triangle taking into account the specified `offset`. All other elements are `true`.
378 ///
379 /// # Example
380 /// ```rust
381 /// use burn_tensor::backend::Backend;
382 /// use burn_tensor::{Tensor, Bool};
383 ///
384 /// fn example<B: Backend>() {
385 /// let mask = Tensor::<B, 2, Bool>::tril_mask([3, 3], 0, &Default::default());
386 /// println!("{mask}");
387 /// // [[false, true, true],
388 /// // [false, false, true],
389 /// // [false, false, false]]
390 /// }
391 /// ```
392 pub fn tril_mask<S: Into<Shape>>(shape: S, offset: i64, device: &B::Device) -> Self {
393 Self::tri_mask(shape, TriPart::Lower, offset, device)
394 }
395
396 /// Creates a mask for the diagonal of a matrix, which can be used to fill the specified
397 /// area with a value.
398 ///
399 /// This function generates a boolean tensor representing the mask of the diagonal of a matrix.
400 ///
401 /// # Arguments
402 ///
403 /// * `shape`: The shape of the matrix.
404 /// * `offset`: The offset from the diagonal, where 0 means the diagonal, and positive values shift
405 /// towards the upper triangle.
406 /// * `device`: The device on which the tensor will be allocated.
407 ///
408 /// # Returns
409 ///
410 /// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the
411 /// diagonal. All other elements are `true`.
412 ///
413 /// # Example
414 /// ```rust
415 /// use burn_tensor::backend::Backend;
416 /// use burn_tensor::{Tensor, Bool};
417 ///
418 /// fn example<B: Backend>() {
419 /// let mask = Tensor::<B, 2, Bool>::diag_mask([3, 3], 0, &Default::default());
420 /// println!("{mask}");
421 /// // [[false, true, true],
422 /// // [true, false, true],
423 /// // [true, true, false]]
424 /// }
425 /// ```
426 pub fn diag_mask<S: Into<Shape>>(shape: S, offset: i64, device: &B::Device) -> Self {
427 Self::tri_mask(shape, TriPart::Diagonal, offset, device)
428 }
429}