Skip to main content

burn_tensor/tensor/api/
bool.rs

1use crate::{Bool, Cast, Int, Shape, Tensor, TensorData, TensorPrimitive, backend::Backend};
2use alloc::{vec, vec::Vec};
3use burn_backend::get_device_settings;
4
5use crate::try_read_sync;
6
/// The part of the tensor to keep when creating a triangular mask.
///
/// Each variant selects the region that is left as `false` in the generated mask,
/// measured relative to the (possibly offset) main diagonal.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TriPart {
    /// Upper triangular part.
    Upper,

    /// Lower triangular part.
    Lower,

    /// Diagonal part.
    Diagonal,
}
18
19impl<B, const D: usize> Tensor<B, D, Bool>
20where
21    B: Backend,
22{
23    /// Create a boolean tensor from data on the given device.
24    ///
25    /// # Arguments
26    ///
27    /// * `data` - The tensor data.
28    /// * `device` - The device on which the tensor will be allocated.
29    ///
30    /// # Returns
31    ///
32    /// A boolean tensor.
33    ///
34    /// # Example
35    ///
36    /// ```rust
37    /// use burn_tensor::backend::Backend;
38    /// use burn_tensor::{Tensor, Bool};
39    ///
40    /// fn example<B: Backend>() {
41    ///     let device = Default::default();
42    ///     let tensor = Tensor::<B, 2, Bool>::from_bool([[true, false], [false, true]].into(), &device);
43    ///     println!("{tensor}");
44    /// }
45    /// ```
46    pub fn from_bool(data: TensorData, device: &B::Device) -> Self {
47        Self::from_data(data, device)
48    }
49
50    /// Convert the bool tensor into an int tensor.
51    ///
52    /// # Returns
53    ///
54    /// An integer tensor where `true` is converted to `1` and `false` to `0`.
55    ///
56    /// # Example
57    ///
58    /// ```rust
59    /// use burn_tensor::backend::Backend;
60    /// use burn_tensor::{Tensor, Bool};
61    ///
62    /// fn example<B: Backend>() {
63    ///     let device = Default::default();
64    ///     let bool_tensor = Tensor::<B, 1, Bool>::from_bool([true, false, true].into(), &device);
65    ///     let int_tensor = bool_tensor.int();
66    ///     println!("{int_tensor}"); // [1, 0, 1]
67    /// }
68    /// ```
69    pub fn int(self) -> Tensor<B, D, Int> {
70        let out_dtype = get_device_settings::<B>(&self.device()).int_dtype;
71        Tensor::new(B::bool_into_int(self.primitive, out_dtype))
72    }
73
74    /// Convert the bool tensor into a float tensor.
75    ///
76    /// # Returns
77    ///
78    /// A float tensor where `true` is converted to `1.0` and `false` to `0.0`.
79    ///
80    /// # Example
81    ///
82    /// ```rust
83    /// use burn_tensor::backend::Backend;
84    /// use burn_tensor::{Tensor, Bool};
85    ///
86    /// fn example<B: Backend>() {
87    ///     let device = Default::default();
88    ///     let bool_tensor = Tensor::<B, 1, Bool>::from_bool([true, false, true].into(), &device);
89    ///     let float_tensor = bool_tensor.float();
90    ///     println!("{float_tensor}"); // [1.0, 0.0, 1.0]
91    /// }
92    /// ```
93    pub fn float(self) -> Tensor<B, D> {
94        let out_dtype = get_device_settings::<B>(&self.device()).float_dtype;
95        Tensor::new(TensorPrimitive::Float(B::bool_into_float(
96            self.primitive,
97            out_dtype,
98        )))
99    }
100
101    /// Converts a bool tensor to the specified data type.
102    ///
103    /// Supports casting to [`IntDType`](crate::IntDType) (producing an int tensor)
104    /// or [`FloatDType`](crate::FloatDType) (producing a float tensor).
105    ///
106    /// # Example
107    ///
108    /// ```rust
109    /// use burn_tensor::backend::Backend;
110    /// use burn_tensor::{Tensor, Bool, IntDType, FloatDType};
111    ///
112    /// fn example<B: Backend>() {
113    ///     let device = Default::default();
114    ///     let bool_tensor = Tensor::<B, 1, Bool>::from_bool([true, false, true].into(), &device);
115    ///
116    ///     // Cast to int
117    ///     let int_tensor = bool_tensor.clone().cast(IntDType::I64);
118    ///
119    ///     // Cast to float
120    ///     let float_tensor = bool_tensor.cast(FloatDType::F32);
121    /// }
122    /// ```
123    #[must_use]
124    pub fn cast<T: Cast<B, Bool>>(self, dtype: T) -> Tensor<B, D, T::OutputKind> {
125        Tensor::new(T::cast(self.primitive, dtype))
126    }
127
128    /// Inverses boolean values.
129    ///
130    /// # Example
131    ///
132    /// ```rust
133    /// use burn_tensor::backend::Backend;
134    /// use burn_tensor::{Tensor, Bool};
135    ///
136    /// fn example<B: Backend>() {
137    ///     let device = Default::default();
138    ///     let tensor = Tensor::<B, 2, Bool>::from_bool([[true, false], [false, true]].into(), &device);
139    ///     let inverted = tensor.bool_not();
140    ///     println!("{inverted}"); // [[false, true], [true, false]]
141    /// }
142    /// ```
143    pub fn bool_not(self) -> Self {
144        Tensor::new(B::bool_not(self.primitive))
145    }
146
147    /// Performs logical and (`&&`) on two boolean tensors.
148    ///
149    /// # Arguments
150    ///
151    /// * `rhs` - The right-hand side tensor for the AND operation.
152    ///
153    /// # Returns
154    ///
155    /// A boolean tensor where each element is the result of `self[i] && rhs[i]`.
156    ///
157    /// # Example
158    ///
159    /// ```rust
160    /// use burn_tensor::backend::Backend;
161    /// use burn_tensor::{Tensor, Bool};
162    ///
163    /// fn example<B: Backend>() {
164    ///     let device = Default::default();
165    ///     let a = Tensor::<B, 2, Bool>::from_bool([[true, true], [false, false]].into(), &device);
166    ///     let b = Tensor::<B, 2, Bool>::from_bool([[true, false], [true, false]].into(), &device);
167    ///     let result = a.bool_and(b);
168    ///     println!("{result}"); // [[true, false], [false, false]]
169    /// }
170    /// ```
171    pub fn bool_and(self, rhs: Tensor<B, D, Bool>) -> Tensor<B, D, Bool> {
172        Tensor::new(B::bool_and(self.primitive, rhs.primitive))
173    }
174
175    /// Performs logical or (`||`) on two boolean tensors.
176    ///
177    /// # Arguments
178    ///
179    /// * `rhs` - The right-hand side tensor for the OR operation.
180    ///
181    /// # Returns
182    ///
183    /// A boolean tensor where each element is the result of `self[i] || rhs[i]`.
184    ///
185    /// # Example
186    ///
187    /// ```rust
188    /// use burn_tensor::backend::Backend;
189    /// use burn_tensor::{Tensor, Bool};
190    ///
191    /// fn example<B: Backend>() {
192    ///     let device = Default::default();
193    ///     let a = Tensor::<B, 2, Bool>::from_bool([[true, true], [false, false]].into(), &device);
194    ///     let b = Tensor::<B, 2, Bool>::from_bool([[true, false], [true, false]].into(), &device);
195    ///     let result = a.bool_or(b);
196    ///     println!("{result}"); // [[true, true], [true, false]]
197    /// }
198    /// ```
199    pub fn bool_or(self, rhs: Tensor<B, D, Bool>) -> Tensor<B, D, Bool> {
200        Tensor::new(B::bool_or(self.primitive, rhs.primitive))
201    }
202
203    /// Performs logical xor (`^`) on two boolean tensors.
204    ///
205    /// # Arguments
206    ///
207    /// * `rhs` - The right-hand side tensor for the XOR operation.
208    ///
209    /// # Returns
210    ///
211    /// A boolean tensor where each element is the result of `self[i] ^ rhs[i]`.
212    /// Returns `true` when exactly one of the operands is `true`.
213    ///
214    /// # Example
215    ///
216    /// ```rust
217    /// use burn_tensor::backend::Backend;
218    /// use burn_tensor::{Tensor, Bool};
219    ///
220    /// fn example<B: Backend>() {
221    ///     let device = Default::default();
222    ///     let a = Tensor::<B, 2, Bool>::from_bool([[true, true], [false, false]].into(), &device);
223    ///     let b = Tensor::<B, 2, Bool>::from_bool([[true, false], [true, false]].into(), &device);
224    ///     let result = a.bool_xor(b);
225    ///     println!("{result}"); // [[false, true], [true, false]]
226    /// }
227    /// ```
228    pub fn bool_xor(self, rhs: Tensor<B, D, Bool>) -> Tensor<B, D, Bool> {
229        Tensor::new(B::bool_xor(self.primitive, rhs.primitive))
230    }
231
232    /// Compute the indices of `true` elements in the tensor (i.e., non-zero for boolean tensors).
233    ///
234    /// # Returns
235    ///
236    /// A vector of tensors, one for each dimension of the given tensor, containing the indices of
237    /// the non-zero elements in that dimension.
238    ///
239    /// # Example
240    ///
241    /// ```rust
242    /// use burn_tensor::backend::Backend;
243    /// use burn_tensor::{Tensor, Bool};
244    ///
245    /// fn example<B: Backend>() {
246    ///     let device = Default::default();
247    ///     let tensor = Tensor::<B, 2, Bool>::from_bool(
248    ///         [[true, false, true], [false, true, false], [false, true, false]].into(),
249    ///         &device,
250    ///     );
251    ///     let indices = tensor.nonzero();
252    ///     println!("{}", indices[0]); // [0, 0, 1, 2]
253    ///     println!("{}", indices[1]); // [0, 2, 1, 1]
254    /// }
255    /// ```
256    pub fn nonzero(self) -> Vec<Tensor<B, 1, Int>> {
257        try_read_sync(self.nonzero_async())
258            .expect("Failed to read tensor data synchronously. Try using nonzero_async instead.")
259    }
260
261    /// Compute the indices of `true` elements in the tensor (i.e., non-zero for boolean tensors).
262    ///
263    /// # Returns
264    ///
265    /// A vector of tensors, one for each dimension of the given tensor, containing the indices of
266    /// the non-zero elements in that dimension.
267    pub async fn nonzero_async(self) -> Vec<Tensor<B, 1, Int>> {
268        let indices = self.argwhere_async().await;
269
270        if indices.shape().num_elements() == 0 {
271            // Return empty vec when all elements are zero
272            return vec![];
273        }
274
275        let dims = indices.shape();
276        indices
277            .chunk(dims[1], 1)
278            .into_iter()
279            .map(|t| t.reshape(Shape::new([dims[0]])))
280            .collect()
281    }
282
283    /// Compute the indices of the elements that are true, grouped by element.
284    ///
285    /// # Returns
286    ///
287    /// A tensor containing the indices of all non-zero elements of the given tensor. Each row in the
288    /// result contains the indices of a non-zero element.
289    ///
290    /// # Example
291    ///
292    /// ```rust
293    /// use burn_tensor::backend::Backend;
294    /// use burn_tensor::{Tensor, Bool};
295    ///
296    /// fn example<B: Backend>() {
297    ///     let device = Default::default();
298    ///     let tensor = Tensor::<B, 2, Bool>::from_bool(
299    ///         [[true, false, true], [false, true, false], [false, true, false]].into(),
300    ///         &device,
301    ///     );
302    ///     let indices = tensor.argwhere();
303    ///     println!("{indices}"); // [[0, 0], [0, 2], [1, 1], [2, 1]]
304    /// }
305    /// ```
306    pub fn argwhere(self) -> Tensor<B, 2, Int> {
307        try_read_sync(self.argwhere_async())
308            .expect("Failed to read tensor data synchronously. Try using argwhere_async instead.")
309    }
310
311    /// Compute the indices of the elements that are true, grouped by element.
312    ///
313    /// # Returns
314    ///
315    /// A tensor containing the indices of all non-zero elements of the given tensor. Each row in the
316    /// result contains the indices of a non-zero element.
317    pub async fn argwhere_async(self) -> Tensor<B, 2, Int> {
318        let out_dtype = get_device_settings::<B>(&self.device()).int_dtype;
319        Tensor::new(B::bool_argwhere(self.primitive, out_dtype).await)
320    }
321
322    /// Creates a mask for the upper, lower triangle, or diagonal of a matrix, which can be used to
323    /// fill the specified area with a value.
324    fn tri_mask<S: Into<Shape>>(
325        shape: S,
326        tri_part: TriPart,
327        offset: i64,
328        device: &B::Device,
329    ) -> Self {
330        let shape: Shape = shape.into();
331        let height = shape[D - 2];
332        let width = shape[D - 1];
333
334        // Generate row and column index tensors.
335        let row_indices: Tensor<B, 1, Int> = Tensor::arange(0..height as i64, device);
336        let col_indices: Tensor<B, 1, Int> = Tensor::arange(0..width as i64, device);
337
338        // Prepare shapes for broadcasting.
339        let mut row_shape = [1; D];
340        row_shape[D - 2] = height;
341        let mut col_shape = [1; D];
342        col_shape[D - 1] = width;
343
344        // Reshape for broadcasting.
345        let row_broadcast: Tensor<B, D, Int> = row_indices.reshape(Shape::new(row_shape));
346        let col_broadcast = col_indices.reshape(Shape::new(col_shape));
347
348        // Broadcasting trick to create a matrix that facilitates comparison for mask generation.
349        let matrix = row_broadcast.clone() - (col_broadcast.clone() - offset);
350
351        // Select the appropriate comparison function based on `tri_part`.
352        let compare = match tri_part {
353            TriPart::Upper => Tensor::greater_elem,
354            TriPart::Lower => Tensor::lower_elem,
355            TriPart::Diagonal => Tensor::not_equal_elem,
356        };
357
358        // Generate and return the mask by applying the comparison to the matrix.
359        compare(matrix, 0).unsqueeze()
360    }
361
362    /// Creates a mask for the upper triangle of a matrix, which can be used to fill the specified
363    /// area with a value.
364    ///
365    /// This function generates a boolean tensor representing the mask of the upper triangle of a matrix.
366    ///
367    /// # Arguments
368    ///
369    /// * `shape`: The shape of the matrix.
370    /// * `offset`: The offset from the diagonal, where 0 means the diagonal, and positive values shift
371    ///   towards the upper triangle.
372    /// * `device`: The device on which the tensor will be allocated.
373    ///
374    /// # Returns
375    ///
376    /// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the
377    /// upper triangle taking into account the specified `offset`. All other elements are `true`.
378    ///
379    /// # Example
380    /// ```rust
381    /// use burn_tensor::backend::Backend;
382    /// use burn_tensor::{Tensor, Bool};
383    ///
384    /// fn example<B: Backend>() {
385    ///   let mask = Tensor::<B, 2, Bool>::triu_mask([3, 3], 0, &Default::default());
386    ///   println!("{mask}");
387    ///   // [[false, false, false],
388    ///   //  [true, false, false],
389    ///   //  [true, true, false]]
390    /// }
391    /// ```
392    pub fn triu_mask<S: Into<Shape>>(shape: S, offset: i64, device: &B::Device) -> Self {
393        Self::tri_mask(shape, TriPart::Upper, offset, device)
394    }
395
396    /// Creates a mask for the lower triangle of a matrix, which can be used to fill the specified
397    /// area with a value.
398    ///
399    /// This function generates a boolean tensor representing the mask of the lower triangle of a matrix.
400    ///
401    /// # Arguments
402    ///
403    /// * `shape`: The shape of the matrix.
404    /// * `offset`: The offset from the diagonal, where 0 means the diagonal, and negative values shift
405    ///   towards the lower triangle.
406    /// * `device`: The device on which the tensor will be allocated.
407    ///
408    /// # Returns
409    ///
410    /// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the
411    /// lower triangle taking into account the specified `offset`. All other elements are `true`.
412    ///
413    /// # Example
414    /// ```rust
415    /// use burn_tensor::backend::Backend;
416    /// use burn_tensor::{Tensor, Bool};
417    ///
418    /// fn example<B: Backend>() {
419    ///   let mask = Tensor::<B, 2, Bool>::tril_mask([3, 3], 0, &Default::default());
420    ///   println!("{mask}");
421    ///   // [[false, true, true],
422    ///   //  [false, false, true],
423    ///   //  [false, false, false]]
424    /// }
425    /// ```
426    pub fn tril_mask<S: Into<Shape>>(shape: S, offset: i64, device: &B::Device) -> Self {
427        Self::tri_mask(shape, TriPart::Lower, offset, device)
428    }
429
430    /// Creates a mask for the diagonal of a matrix, which can be used to fill the specified
431    /// area with a value.
432    ///
433    /// This function generates a boolean tensor representing the mask of the diagonal of a matrix.
434    ///
435    /// # Arguments
436    ///
437    /// * `shape`: The shape of the matrix.
438    /// * `offset`: The offset from the diagonal, where 0 means the diagonal, and positive values shift
439    ///   towards the upper triangle.
440    /// * `device`: The device on which the tensor will be allocated.
441    ///
442    /// # Returns
443    ///
444    /// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the
445    /// diagonal. All other elements are `true`.
446    ///
447    /// # Example
448    /// ```rust
449    /// use burn_tensor::backend::Backend;
450    /// use burn_tensor::{Tensor, Bool};
451    ///
452    /// fn example<B: Backend>() {
453    ///   let mask = Tensor::<B, 2, Bool>::diag_mask([3, 3], 0, &Default::default());
454    ///   println!("{mask}");
455    ///   // [[false, true, true],
456    ///   //  [true, false, true],
457    ///   //  [true, true, false]]
458    /// }
459    /// ```
460    pub fn diag_mask<S: Into<Shape>>(shape: S, offset: i64, device: &B::Device) -> Self {
461        Self::tri_mask(shape, TriPart::Diagonal, offset, device)
462    }
463}