burn_tensor/tensor/api/bool.rs
1use crate::{Bool, Cast, Int, Shape, Tensor, TensorData, TensorPrimitive, backend::Backend};
2use alloc::{vec, vec::Vec};
3use burn_backend::get_device_settings;
4
5use crate::try_read_sync;
6
/// The part of the tensor to keep when creating a triangular mask.
///
/// Each variant selects the region of the matrix (its last two dimensions) that the
/// generated mask leaves as `false`; every other position of the mask is `true`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TriPart {
    /// Upper triangular part.
    Upper,

    /// Lower triangular part.
    Lower,

    /// Diagonal part.
    Diagonal,
}
18
19impl<B, const D: usize> Tensor<B, D, Bool>
20where
21 B: Backend,
22{
23 /// Create a boolean tensor from data on the given device.
24 ///
25 /// # Arguments
26 ///
27 /// * `data` - The tensor data.
28 /// * `device` - The device on which the tensor will be allocated.
29 ///
30 /// # Returns
31 ///
32 /// A boolean tensor.
33 ///
34 /// # Example
35 ///
36 /// ```rust
37 /// use burn_tensor::backend::Backend;
38 /// use burn_tensor::{Tensor, Bool};
39 ///
40 /// fn example<B: Backend>() {
41 /// let device = Default::default();
42 /// let tensor = Tensor::<B, 2, Bool>::from_bool([[true, false], [false, true]].into(), &device);
43 /// println!("{tensor}");
44 /// }
45 /// ```
46 pub fn from_bool(data: TensorData, device: &B::Device) -> Self {
47 Self::from_data(data, device)
48 }
49
50 /// Convert the bool tensor into an int tensor.
51 ///
52 /// # Returns
53 ///
54 /// An integer tensor where `true` is converted to `1` and `false` to `0`.
55 ///
56 /// # Example
57 ///
58 /// ```rust
59 /// use burn_tensor::backend::Backend;
60 /// use burn_tensor::{Tensor, Bool};
61 ///
62 /// fn example<B: Backend>() {
63 /// let device = Default::default();
64 /// let bool_tensor = Tensor::<B, 1, Bool>::from_bool([true, false, true].into(), &device);
65 /// let int_tensor = bool_tensor.int();
66 /// println!("{int_tensor}"); // [1, 0, 1]
67 /// }
68 /// ```
69 pub fn int(self) -> Tensor<B, D, Int> {
70 let out_dtype = get_device_settings::<B>(&self.device()).int_dtype;
71 Tensor::new(B::bool_into_int(self.primitive, out_dtype))
72 }
73
74 /// Convert the bool tensor into a float tensor.
75 ///
76 /// # Returns
77 ///
78 /// A float tensor where `true` is converted to `1.0` and `false` to `0.0`.
79 ///
80 /// # Example
81 ///
82 /// ```rust
83 /// use burn_tensor::backend::Backend;
84 /// use burn_tensor::{Tensor, Bool};
85 ///
86 /// fn example<B: Backend>() {
87 /// let device = Default::default();
88 /// let bool_tensor = Tensor::<B, 1, Bool>::from_bool([true, false, true].into(), &device);
89 /// let float_tensor = bool_tensor.float();
90 /// println!("{float_tensor}"); // [1.0, 0.0, 1.0]
91 /// }
92 /// ```
93 pub fn float(self) -> Tensor<B, D> {
94 let out_dtype = get_device_settings::<B>(&self.device()).float_dtype;
95 Tensor::new(TensorPrimitive::Float(B::bool_into_float(
96 self.primitive,
97 out_dtype,
98 )))
99 }
100
101 /// Converts a bool tensor to the specified data type.
102 ///
103 /// Supports casting to [`IntDType`](crate::IntDType) (producing an int tensor)
104 /// or [`FloatDType`](crate::FloatDType) (producing a float tensor).
105 ///
106 /// # Example
107 ///
108 /// ```rust
109 /// use burn_tensor::backend::Backend;
110 /// use burn_tensor::{Tensor, Bool, IntDType, FloatDType};
111 ///
112 /// fn example<B: Backend>() {
113 /// let device = Default::default();
114 /// let bool_tensor = Tensor::<B, 1, Bool>::from_bool([true, false, true].into(), &device);
115 ///
116 /// // Cast to int
117 /// let int_tensor = bool_tensor.clone().cast(IntDType::I64);
118 ///
119 /// // Cast to float
120 /// let float_tensor = bool_tensor.cast(FloatDType::F32);
121 /// }
122 /// ```
123 #[must_use]
124 pub fn cast<T: Cast<B, Bool>>(self, dtype: T) -> Tensor<B, D, T::OutputKind> {
125 Tensor::new(T::cast(self.primitive, dtype))
126 }
127
128 /// Inverses boolean values.
129 ///
130 /// # Example
131 ///
132 /// ```rust
133 /// use burn_tensor::backend::Backend;
134 /// use burn_tensor::{Tensor, Bool};
135 ///
136 /// fn example<B: Backend>() {
137 /// let device = Default::default();
138 /// let tensor = Tensor::<B, 2, Bool>::from_bool([[true, false], [false, true]].into(), &device);
139 /// let inverted = tensor.bool_not();
140 /// println!("{inverted}"); // [[false, true], [true, false]]
141 /// }
142 /// ```
143 pub fn bool_not(self) -> Self {
144 Tensor::new(B::bool_not(self.primitive))
145 }
146
147 /// Performs logical and (`&&`) on two boolean tensors.
148 ///
149 /// # Arguments
150 ///
151 /// * `rhs` - The right-hand side tensor for the AND operation.
152 ///
153 /// # Returns
154 ///
155 /// A boolean tensor where each element is the result of `self[i] && rhs[i]`.
156 ///
157 /// # Example
158 ///
159 /// ```rust
160 /// use burn_tensor::backend::Backend;
161 /// use burn_tensor::{Tensor, Bool};
162 ///
163 /// fn example<B: Backend>() {
164 /// let device = Default::default();
165 /// let a = Tensor::<B, 2, Bool>::from_bool([[true, true], [false, false]].into(), &device);
166 /// let b = Tensor::<B, 2, Bool>::from_bool([[true, false], [true, false]].into(), &device);
167 /// let result = a.bool_and(b);
168 /// println!("{result}"); // [[true, false], [false, false]]
169 /// }
170 /// ```
171 pub fn bool_and(self, rhs: Tensor<B, D, Bool>) -> Tensor<B, D, Bool> {
172 Tensor::new(B::bool_and(self.primitive, rhs.primitive))
173 }
174
175 /// Performs logical or (`||`) on two boolean tensors.
176 ///
177 /// # Arguments
178 ///
179 /// * `rhs` - The right-hand side tensor for the OR operation.
180 ///
181 /// # Returns
182 ///
183 /// A boolean tensor where each element is the result of `self[i] || rhs[i]`.
184 ///
185 /// # Example
186 ///
187 /// ```rust
188 /// use burn_tensor::backend::Backend;
189 /// use burn_tensor::{Tensor, Bool};
190 ///
191 /// fn example<B: Backend>() {
192 /// let device = Default::default();
193 /// let a = Tensor::<B, 2, Bool>::from_bool([[true, true], [false, false]].into(), &device);
194 /// let b = Tensor::<B, 2, Bool>::from_bool([[true, false], [true, false]].into(), &device);
195 /// let result = a.bool_or(b);
196 /// println!("{result}"); // [[true, true], [true, false]]
197 /// }
198 /// ```
199 pub fn bool_or(self, rhs: Tensor<B, D, Bool>) -> Tensor<B, D, Bool> {
200 Tensor::new(B::bool_or(self.primitive, rhs.primitive))
201 }
202
203 /// Performs logical xor (`^`) on two boolean tensors.
204 ///
205 /// # Arguments
206 ///
207 /// * `rhs` - The right-hand side tensor for the XOR operation.
208 ///
209 /// # Returns
210 ///
211 /// A boolean tensor where each element is the result of `self[i] ^ rhs[i]`.
212 /// Returns `true` when exactly one of the operands is `true`.
213 ///
214 /// # Example
215 ///
216 /// ```rust
217 /// use burn_tensor::backend::Backend;
218 /// use burn_tensor::{Tensor, Bool};
219 ///
220 /// fn example<B: Backend>() {
221 /// let device = Default::default();
222 /// let a = Tensor::<B, 2, Bool>::from_bool([[true, true], [false, false]].into(), &device);
223 /// let b = Tensor::<B, 2, Bool>::from_bool([[true, false], [true, false]].into(), &device);
224 /// let result = a.bool_xor(b);
225 /// println!("{result}"); // [[false, true], [true, false]]
226 /// }
227 /// ```
228 pub fn bool_xor(self, rhs: Tensor<B, D, Bool>) -> Tensor<B, D, Bool> {
229 Tensor::new(B::bool_xor(self.primitive, rhs.primitive))
230 }
231
232 /// Compute the indices of `true` elements in the tensor (i.e., non-zero for boolean tensors).
233 ///
234 /// # Returns
235 ///
236 /// A vector of tensors, one for each dimension of the given tensor, containing the indices of
237 /// the non-zero elements in that dimension.
238 ///
239 /// # Example
240 ///
241 /// ```rust
242 /// use burn_tensor::backend::Backend;
243 /// use burn_tensor::{Tensor, Bool};
244 ///
245 /// fn example<B: Backend>() {
246 /// let device = Default::default();
247 /// let tensor = Tensor::<B, 2, Bool>::from_bool(
248 /// [[true, false, true], [false, true, false], [false, true, false]].into(),
249 /// &device,
250 /// );
251 /// let indices = tensor.nonzero();
252 /// println!("{}", indices[0]); // [0, 0, 1, 2]
253 /// println!("{}", indices[1]); // [0, 2, 1, 1]
254 /// }
255 /// ```
256 pub fn nonzero(self) -> Vec<Tensor<B, 1, Int>> {
257 try_read_sync(self.nonzero_async())
258 .expect("Failed to read tensor data synchronously. Try using nonzero_async instead.")
259 }
260
261 /// Compute the indices of `true` elements in the tensor (i.e., non-zero for boolean tensors).
262 ///
263 /// # Returns
264 ///
265 /// A vector of tensors, one for each dimension of the given tensor, containing the indices of
266 /// the non-zero elements in that dimension.
267 pub async fn nonzero_async(self) -> Vec<Tensor<B, 1, Int>> {
268 let indices = self.argwhere_async().await;
269
270 if indices.shape().num_elements() == 0 {
271 // Return empty vec when all elements are zero
272 return vec![];
273 }
274
275 let dims = indices.shape();
276 indices
277 .chunk(dims[1], 1)
278 .into_iter()
279 .map(|t| t.reshape(Shape::new([dims[0]])))
280 .collect()
281 }
282
283 /// Compute the indices of the elements that are true, grouped by element.
284 ///
285 /// # Returns
286 ///
287 /// A tensor containing the indices of all non-zero elements of the given tensor. Each row in the
288 /// result contains the indices of a non-zero element.
289 ///
290 /// # Example
291 ///
292 /// ```rust
293 /// use burn_tensor::backend::Backend;
294 /// use burn_tensor::{Tensor, Bool};
295 ///
296 /// fn example<B: Backend>() {
297 /// let device = Default::default();
298 /// let tensor = Tensor::<B, 2, Bool>::from_bool(
299 /// [[true, false, true], [false, true, false], [false, true, false]].into(),
300 /// &device,
301 /// );
302 /// let indices = tensor.argwhere();
303 /// println!("{indices}"); // [[0, 0], [0, 2], [1, 1], [2, 1]]
304 /// }
305 /// ```
306 pub fn argwhere(self) -> Tensor<B, 2, Int> {
307 try_read_sync(self.argwhere_async())
308 .expect("Failed to read tensor data synchronously. Try using argwhere_async instead.")
309 }
310
311 /// Compute the indices of the elements that are true, grouped by element.
312 ///
313 /// # Returns
314 ///
315 /// A tensor containing the indices of all non-zero elements of the given tensor. Each row in the
316 /// result contains the indices of a non-zero element.
317 pub async fn argwhere_async(self) -> Tensor<B, 2, Int> {
318 let out_dtype = get_device_settings::<B>(&self.device()).int_dtype;
319 Tensor::new(B::bool_argwhere(self.primitive, out_dtype).await)
320 }
321
322 /// Creates a mask for the upper, lower triangle, or diagonal of a matrix, which can be used to
323 /// fill the specified area with a value.
324 fn tri_mask<S: Into<Shape>>(
325 shape: S,
326 tri_part: TriPart,
327 offset: i64,
328 device: &B::Device,
329 ) -> Self {
330 let shape: Shape = shape.into();
331 let height = shape[D - 2];
332 let width = shape[D - 1];
333
334 // Generate row and column index tensors.
335 let row_indices: Tensor<B, 1, Int> = Tensor::arange(0..height as i64, device);
336 let col_indices: Tensor<B, 1, Int> = Tensor::arange(0..width as i64, device);
337
338 // Prepare shapes for broadcasting.
339 let mut row_shape = [1; D];
340 row_shape[D - 2] = height;
341 let mut col_shape = [1; D];
342 col_shape[D - 1] = width;
343
344 // Reshape for broadcasting.
345 let row_broadcast: Tensor<B, D, Int> = row_indices.reshape(Shape::new(row_shape));
346 let col_broadcast = col_indices.reshape(Shape::new(col_shape));
347
348 // Broadcasting trick to create a matrix that facilitates comparison for mask generation.
349 let matrix = row_broadcast.clone() - (col_broadcast.clone() - offset);
350
351 // Select the appropriate comparison function based on `tri_part`.
352 let compare = match tri_part {
353 TriPart::Upper => Tensor::greater_elem,
354 TriPart::Lower => Tensor::lower_elem,
355 TriPart::Diagonal => Tensor::not_equal_elem,
356 };
357
358 // Generate and return the mask by applying the comparison to the matrix.
359 compare(matrix, 0).unsqueeze()
360 }
361
362 /// Creates a mask for the upper triangle of a matrix, which can be used to fill the specified
363 /// area with a value.
364 ///
365 /// This function generates a boolean tensor representing the mask of the upper triangle of a matrix.
366 ///
367 /// # Arguments
368 ///
369 /// * `shape`: The shape of the matrix.
370 /// * `offset`: The offset from the diagonal, where 0 means the diagonal, and positive values shift
371 /// towards the upper triangle.
372 /// * `device`: The device on which the tensor will be allocated.
373 ///
374 /// # Returns
375 ///
376 /// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the
377 /// upper triangle taking into account the specified `offset`. All other elements are `true`.
378 ///
379 /// # Example
380 /// ```rust
381 /// use burn_tensor::backend::Backend;
382 /// use burn_tensor::{Tensor, Bool};
383 ///
384 /// fn example<B: Backend>() {
385 /// let mask = Tensor::<B, 2, Bool>::triu_mask([3, 3], 0, &Default::default());
386 /// println!("{mask}");
387 /// // [[false, false, false],
388 /// // [true, false, false],
389 /// // [true, true, false]]
390 /// }
391 /// ```
392 pub fn triu_mask<S: Into<Shape>>(shape: S, offset: i64, device: &B::Device) -> Self {
393 Self::tri_mask(shape, TriPart::Upper, offset, device)
394 }
395
396 /// Creates a mask for the lower triangle of a matrix, which can be used to fill the specified
397 /// area with a value.
398 ///
399 /// This function generates a boolean tensor representing the mask of the lower triangle of a matrix.
400 ///
401 /// # Arguments
402 ///
403 /// * `shape`: The shape of the matrix.
404 /// * `offset`: The offset from the diagonal, where 0 means the diagonal, and negative values shift
405 /// towards the lower triangle.
406 /// * `device`: The device on which the tensor will be allocated.
407 ///
408 /// # Returns
409 ///
410 /// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the
411 /// lower triangle taking into account the specified `offset`. All other elements are `true`.
412 ///
413 /// # Example
414 /// ```rust
415 /// use burn_tensor::backend::Backend;
416 /// use burn_tensor::{Tensor, Bool};
417 ///
418 /// fn example<B: Backend>() {
419 /// let mask = Tensor::<B, 2, Bool>::tril_mask([3, 3], 0, &Default::default());
420 /// println!("{mask}");
421 /// // [[false, true, true],
422 /// // [false, false, true],
423 /// // [false, false, false]]
424 /// }
425 /// ```
426 pub fn tril_mask<S: Into<Shape>>(shape: S, offset: i64, device: &B::Device) -> Self {
427 Self::tri_mask(shape, TriPart::Lower, offset, device)
428 }
429
430 /// Creates a mask for the diagonal of a matrix, which can be used to fill the specified
431 /// area with a value.
432 ///
433 /// This function generates a boolean tensor representing the mask of the diagonal of a matrix.
434 ///
435 /// # Arguments
436 ///
437 /// * `shape`: The shape of the matrix.
438 /// * `offset`: The offset from the diagonal, where 0 means the diagonal, and positive values shift
439 /// towards the upper triangle.
440 /// * `device`: The device on which the tensor will be allocated.
441 ///
442 /// # Returns
443 ///
444 /// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the
445 /// diagonal. All other elements are `true`.
446 ///
447 /// # Example
448 /// ```rust
449 /// use burn_tensor::backend::Backend;
450 /// use burn_tensor::{Tensor, Bool};
451 ///
452 /// fn example<B: Backend>() {
453 /// let mask = Tensor::<B, 2, Bool>::diag_mask([3, 3], 0, &Default::default());
454 /// println!("{mask}");
455 /// // [[false, true, true],
456 /// // [true, false, true],
457 /// // [true, true, false]]
458 /// }
459 /// ```
460 pub fn diag_mask<S: Into<Shape>>(shape: S, offset: i64, device: &B::Device) -> Self {
461 Self::tri_mask(shape, TriPart::Diagonal, offset, device)
462 }
463}