burn_tensor/tensor/api/bool.rs

1use crate::{Bool, Int, Shape, Tensor, TensorData, TensorPrimitive, backend::Backend};
2use alloc::{vec, vec::Vec};
3
4use crate::try_read_sync;
5
/// Selects which region of a matrix a triangular/diagonal mask targets.
///
/// Consumed by `tri_mask`, which maps each variant to the comparison applied to its
/// row/column index matrix.
enum TriPart {
    /// The upper triangle: elements on and above the (offset) diagonal.
    Upper,
    /// The lower triangle: elements on and below the (offset) diagonal.
    Lower,
    /// Only the elements lying exactly on the (offset) diagonal.
    Diagonal,
}
17
18impl<B, const D: usize> Tensor<B, D, Bool>
19where
20    B: Backend,
21{
22    /// Create a boolean tensor from data on the given device.
23    ///
24    /// # Arguments
25    ///
26    /// * `data` - The tensor data.
27    /// * `device` - The device on which the tensor will be allocated.
28    ///
29    /// # Returns
30    ///
31    /// A boolean tensor.
32    ///
33    /// # Example
34    ///
35    /// ```rust
36    /// use burn_tensor::backend::Backend;
37    /// use burn_tensor::{Tensor, Bool};
38    ///
39    /// fn example<B: Backend>() {
40    ///     let device = Default::default();
41    ///     let tensor = Tensor::<B, 2, Bool>::from_bool([[true, false], [false, true]].into(), &device);
42    ///     println!("{tensor}");
43    /// }
44    /// ```
45    pub fn from_bool(data: TensorData, device: &B::Device) -> Self {
46        Self::new(B::bool_from_data(data.convert::<B::BoolElem>(), device))
47    }
48
49    /// Convert the bool tensor into an int tensor.
50    ///
51    /// # Returns
52    ///
53    /// An integer tensor where `true` is converted to `1` and `false` to `0`.
54    ///
55    /// # Example
56    ///
57    /// ```rust
58    /// use burn_tensor::backend::Backend;
59    /// use burn_tensor::{Tensor, Bool};
60    ///
61    /// fn example<B: Backend>() {
62    ///     let device = Default::default();
63    ///     let bool_tensor = Tensor::<B, 1, Bool>::from_bool([true, false, true].into(), &device);
64    ///     let int_tensor = bool_tensor.int();
65    ///     println!("{int_tensor}"); // [1, 0, 1]
66    /// }
67    /// ```
68    pub fn int(self) -> Tensor<B, D, Int> {
69        Tensor::new(B::bool_into_int(self.primitive))
70    }
71
72    /// Convert the bool tensor into a float tensor.
73    ///
74    /// # Returns
75    ///
76    /// A float tensor where `true` is converted to `1.0` and `false` to `0.0`.
77    ///
78    /// # Example
79    ///
80    /// ```rust
81    /// use burn_tensor::backend::Backend;
82    /// use burn_tensor::{Tensor, Bool};
83    ///
84    /// fn example<B: Backend>() {
85    ///     let device = Default::default();
86    ///     let bool_tensor = Tensor::<B, 1, Bool>::from_bool([true, false, true].into(), &device);
87    ///     let float_tensor = bool_tensor.float();
88    ///     println!("{float_tensor}"); // [1.0, 0.0, 1.0]
89    /// }
90    /// ```
91    pub fn float(self) -> Tensor<B, D> {
92        Tensor::new(TensorPrimitive::Float(B::bool_into_float(self.primitive)))
93    }
94
95    /// Inverses boolean values.
96    ///
97    /// # Example
98    ///
99    /// ```rust
100    /// use burn_tensor::backend::Backend;
101    /// use burn_tensor::{Tensor, Bool};
102    ///
103    /// fn example<B: Backend>() {
104    ///     let device = Default::default();
105    ///     let tensor = Tensor::<B, 2, Bool>::from_bool([[true, false], [false, true]].into(), &device);
106    ///     let inverted = tensor.bool_not();
107    ///     println!("{inverted}"); // [[false, true], [true, false]]
108    /// }
109    /// ```
110    pub fn bool_not(self) -> Self {
111        Tensor::new(B::bool_not(self.primitive))
112    }
113
114    /// Performs logical and (`&&`) on two boolean tensors.
115    ///
116    /// # Arguments
117    ///
118    /// * `rhs` - The right-hand side tensor for the AND operation.
119    ///
120    /// # Returns
121    ///
122    /// A boolean tensor where each element is the result of `self[i] && rhs[i]`.
123    ///
124    /// # Example
125    ///
126    /// ```rust
127    /// use burn_tensor::backend::Backend;
128    /// use burn_tensor::{Tensor, Bool};
129    ///
130    /// fn example<B: Backend>() {
131    ///     let device = Default::default();
132    ///     let a = Tensor::<B, 2, Bool>::from_bool([[true, true], [false, false]].into(), &device);
133    ///     let b = Tensor::<B, 2, Bool>::from_bool([[true, false], [true, false]].into(), &device);
134    ///     let result = a.bool_and(b);
135    ///     println!("{result}"); // [[true, false], [false, false]]
136    /// }
137    /// ```
138    pub fn bool_and(self, rhs: Tensor<B, D, Bool>) -> Tensor<B, D, Bool> {
139        Tensor::new(B::bool_and(self.primitive, rhs.primitive))
140    }
141
142    /// Performs logical or (`||`) on two boolean tensors.
143    ///
144    /// # Arguments
145    ///
146    /// * `rhs` - The right-hand side tensor for the OR operation.
147    ///
148    /// # Returns
149    ///
150    /// A boolean tensor where each element is the result of `self[i] || rhs[i]`.
151    ///
152    /// # Example
153    ///
154    /// ```rust
155    /// use burn_tensor::backend::Backend;
156    /// use burn_tensor::{Tensor, Bool};
157    ///
158    /// fn example<B: Backend>() {
159    ///     let device = Default::default();
160    ///     let a = Tensor::<B, 2, Bool>::from_bool([[true, true], [false, false]].into(), &device);
161    ///     let b = Tensor::<B, 2, Bool>::from_bool([[true, false], [true, false]].into(), &device);
162    ///     let result = a.bool_or(b);
163    ///     println!("{result}"); // [[true, true], [true, false]]
164    /// }
165    /// ```
166    pub fn bool_or(self, rhs: Tensor<B, D, Bool>) -> Tensor<B, D, Bool> {
167        Tensor::new(B::bool_or(self.primitive, rhs.primitive))
168    }
169
170    /// Performs logical xor (`^`) on two boolean tensors.
171    ///
172    /// # Arguments
173    ///
174    /// * `rhs` - The right-hand side tensor for the XOR operation.
175    ///
176    /// # Returns
177    ///
178    /// A boolean tensor where each element is the result of `self[i] ^ rhs[i]`.
179    /// Returns `true` when exactly one of the operands is `true`.
180    ///
181    /// # Example
182    ///
183    /// ```rust
184    /// use burn_tensor::backend::Backend;
185    /// use burn_tensor::{Tensor, Bool};
186    ///
187    /// fn example<B: Backend>() {
188    ///     let device = Default::default();
189    ///     let a = Tensor::<B, 2, Bool>::from_bool([[true, true], [false, false]].into(), &device);
190    ///     let b = Tensor::<B, 2, Bool>::from_bool([[true, false], [true, false]].into(), &device);
191    ///     let result = a.bool_xor(b);
192    ///     println!("{result}"); // [[false, true], [true, false]]
193    /// }
194    /// ```
195    pub fn bool_xor(self, rhs: Tensor<B, D, Bool>) -> Tensor<B, D, Bool> {
196        Tensor::new(B::bool_xor(self.primitive, rhs.primitive))
197    }
198
199    /// Compute the indices of `true` elements in the tensor (i.e., non-zero for boolean tensors).
200    ///
201    /// # Returns
202    ///
203    /// A vector of tensors, one for each dimension of the given tensor, containing the indices of
204    /// the non-zero elements in that dimension.
205    ///
206    /// # Example
207    ///
208    /// ```rust
209    /// use burn_tensor::backend::Backend;
210    /// use burn_tensor::{Tensor, Bool};
211    ///
212    /// fn example<B: Backend>() {
213    ///     let device = Default::default();
214    ///     let tensor = Tensor::<B, 2, Bool>::from_bool(
215    ///         [[true, false, true], [false, true, false], [false, true, false]].into(),
216    ///         &device,
217    ///     );
218    ///     let indices = tensor.nonzero();
219    ///     println!("{}", indices[0]); // [0, 0, 1, 2]
220    ///     println!("{}", indices[1]); // [0, 2, 1, 1]
221    /// }
222    /// ```
223    pub fn nonzero(self) -> Vec<Tensor<B, 1, Int>> {
224        try_read_sync(self.nonzero_async())
225            .expect("Failed to read tensor data synchronously. Try using nonzero_async instead.")
226    }
227
228    /// Compute the indices of `true` elements in the tensor (i.e., non-zero for boolean tensors).
229    ///
230    /// # Returns
231    ///
232    /// A vector of tensors, one for each dimension of the given tensor, containing the indices of
233    /// the non-zero elements in that dimension.
234    pub async fn nonzero_async(self) -> Vec<Tensor<B, 1, Int>> {
235        let indices = self.argwhere_async().await;
236
237        if indices.shape().num_elements() == 0 {
238            // Return empty vec when all elements are zero
239            return vec![];
240        }
241
242        let dims = indices.shape().dims;
243        indices
244            .chunk(dims[1], 1)
245            .into_iter()
246            .map(|t| t.reshape(Shape::new([dims[0]])))
247            .collect()
248    }
249
250    /// Compute the indices of the elements that are true, grouped by element.
251    ///
252    /// # Returns
253    ///
254    /// A tensor containing the indices of all non-zero elements of the given tensor. Each row in the
255    /// result contains the indices of a non-zero element.
256    ///
257    /// # Example
258    ///
259    /// ```rust
260    /// use burn_tensor::backend::Backend;
261    /// use burn_tensor::{Tensor, Bool};
262    ///
263    /// fn example<B: Backend>() {
264    ///     let device = Default::default();
265    ///     let tensor = Tensor::<B, 2, Bool>::from_bool(
266    ///         [[true, false, true], [false, true, false], [false, true, false]].into(),
267    ///         &device,
268    ///     );
269    ///     let indices = tensor.argwhere();
270    ///     println!("{indices}"); // [[0, 0], [0, 2], [1, 1], [2, 1]]
271    /// }
272    /// ```
273    pub fn argwhere(self) -> Tensor<B, 2, Int> {
274        try_read_sync(self.argwhere_async())
275            .expect("Failed to read tensor data synchronously. Try using argwhere_async instead.")
276    }
277
278    /// Compute the indices of the elements that are true, grouped by element.
279    ///
280    /// # Returns
281    ///
282    /// A tensor containing the indices of all non-zero elements of the given tensor. Each row in the
283    /// result contains the indices of a non-zero element.
284    pub async fn argwhere_async(self) -> Tensor<B, 2, Int> {
285        Tensor::new(B::bool_argwhere(self.primitive).await)
286    }
287
288    /// Creates a mask for the upper, lower triangle, or diagonal of a matrix, which can be used to
289    /// fill the specified area with a value.
290    fn tri_mask<S: Into<Shape>>(
291        shape: S,
292        tri_part: TriPart,
293        offset: i64,
294        device: &B::Device,
295    ) -> Self {
296        let shape: Shape = shape.into();
297        let height = shape[D - 2];
298        let width = shape[D - 1];
299
300        // Generate row and column index tensors.
301        let row_indices: Tensor<B, 1, Int> = Tensor::arange(0..height as i64, device);
302        let col_indices: Tensor<B, 1, Int> = Tensor::arange(0..width as i64, device);
303
304        // Prepare shapes for broadcasting.
305        let mut row_shape = [1; D];
306        row_shape[D - 2] = height;
307        let mut col_shape = [1; D];
308        col_shape[D - 1] = width;
309
310        // Reshape for broadcasting.
311        let row_broadcast: Tensor<B, D, Int> = row_indices.reshape(Shape::new(row_shape));
312        let col_broadcast = col_indices.reshape(Shape::new(col_shape));
313
314        // Broadcasting trick to create a matrix that facilitates comparison for mask generation.
315        let matrix = row_broadcast.clone() - (col_broadcast.clone() - offset);
316
317        // Select the appropriate comparison function based on `tri_part`.
318        let compare = match tri_part {
319            TriPart::Upper => Tensor::greater_elem,
320            TriPart::Lower => Tensor::lower_elem,
321            TriPart::Diagonal => Tensor::not_equal_elem,
322        };
323
324        // Generate and return the mask by applying the comparison to the matrix.
325        compare(matrix, 0).unsqueeze()
326    }
327
328    /// Creates a mask for the upper triangle of a matrix, which can be used to fill the specified
329    /// area with a value.
330    ///
331    /// This function generates a boolean tensor representing the mask of the upper triangle of a matrix.
332    ///
333    /// # Arguments
334    ///
335    /// * `shape`: The shape of the matrix.
336    /// * `offset`: The offset from the diagonal, where 0 means the diagonal, and positive values shift
337    ///   towards the upper triangle.
338    /// * `device`: The device on which the tensor will be allocated.
339    ///
340    /// # Returns
341    ///
342    /// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the
343    /// upper triangle taking into account the specified `offset`. All other elements are `true`.
344    ///
345    /// # Example
346    /// ```rust
347    /// use burn_tensor::backend::Backend;
348    /// use burn_tensor::{Tensor, Bool};
349    ///
350    /// fn example<B: Backend>() {
351    ///   let mask = Tensor::<B, 2, Bool>::triu_mask([3, 3], 0, &Default::default());
352    ///   println!("{mask}");
353    ///   // [[false, false, false],
354    ///   //  [true, false, false],
355    ///   //  [true, true, false]]
356    /// }
357    /// ```
358    pub fn triu_mask<S: Into<Shape>>(shape: S, offset: i64, device: &B::Device) -> Self {
359        Self::tri_mask(shape, TriPart::Upper, offset, device)
360    }
361
362    /// Creates a mask for the lower triangle of a matrix, which can be used to fill the specified
363    /// area with a value.
364    ///
365    /// This function generates a boolean tensor representing the mask of the lower triangle of a matrix.
366    ///
367    /// # Arguments
368    ///
369    /// * `shape`: The shape of the matrix.
370    /// * `offset`: The offset from the diagonal, where 0 means the diagonal, and negative values shift
371    ///   towards the lower triangle.
372    /// * `device`: The device on which the tensor will be allocated.
373    ///
374    /// # Returns
375    ///
376    /// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the
377    /// lower triangle taking into account the specified `offset`. All other elements are `true`.
378    ///
379    /// # Example
380    /// ```rust
381    /// use burn_tensor::backend::Backend;
382    /// use burn_tensor::{Tensor, Bool};
383    ///
384    /// fn example<B: Backend>() {
385    ///   let mask = Tensor::<B, 2, Bool>::tril_mask([3, 3], 0, &Default::default());
386    ///   println!("{mask}");
387    ///   // [[false, true, true],
388    ///   //  [false, false, true],
389    ///   //  [false, false, false]]
390    /// }
391    /// ```
392    pub fn tril_mask<S: Into<Shape>>(shape: S, offset: i64, device: &B::Device) -> Self {
393        Self::tri_mask(shape, TriPart::Lower, offset, device)
394    }
395
396    /// Creates a mask for the diagonal of a matrix, which can be used to fill the specified
397    /// area with a value.
398    ///
399    /// This function generates a boolean tensor representing the mask of the diagonal of a matrix.
400    ///
401    /// # Arguments
402    ///
403    /// * `shape`: The shape of the matrix.
404    /// * `offset`: The offset from the diagonal, where 0 means the diagonal, and positive values shift
405    ///   towards the upper triangle.
406    /// * `device`: The device on which the tensor will be allocated.
407    ///
408    /// # Returns
409    ///
410    /// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the
411    /// diagonal. All other elements are `true`.
412    ///
413    /// # Example
414    /// ```rust
415    /// use burn_tensor::backend::Backend;
416    /// use burn_tensor::{Tensor, Bool};
417    ///
418    /// fn example<B: Backend>() {
419    ///   let mask = Tensor::<B, 2, Bool>::diag_mask([3, 3], 0, &Default::default());
420    ///   println!("{mask}");
421    ///   // [[false, true, true],
422    ///   //  [true, false, true],
423    ///   //  [true, true, false]]
424    /// }
425    /// ```
426    pub fn diag_mask<S: Into<Shape>>(shape: S, offset: i64, device: &B::Device) -> Self {
427        Self::tri_mask(shape, TriPart::Diagonal, offset, device)
428    }
429}