burn_tensor/tensor/api/bool.rs

1use crate::{Bool, Int, Shape, Tensor, TensorData, TensorPrimitive, backend::Backend};
2use alloc::{vec, vec::Vec};
3
4use crate::try_read_sync;
5
/// The part of the tensor to keep when creating a triangular mask.
///
/// This is a small fieldless selector passed by value, so it is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TriPart {
    /// Upper triangular part.
    Upper,

    /// Lower triangular part.
    Lower,

    /// Diagonal part.
    Diagonal,
}
17
18impl<B, const D: usize> Tensor<B, D, Bool>
19where
20    B: Backend,
21{
22    /// Create a boolean tensor from data on the given device.
23    pub fn from_bool(data: TensorData, device: &B::Device) -> Self {
24        Self::new(B::bool_from_data(data.convert::<B::BoolElem>(), device))
25    }
26
27    /// Convert the bool tensor into an int tensor.
28    pub fn int(self) -> Tensor<B, D, Int> {
29        Tensor::new(B::bool_into_int(self.primitive))
30    }
31
32    /// Convert the bool tensor into a float tensor.
33    pub fn float(self) -> Tensor<B, D> {
34        Tensor::new(TensorPrimitive::Float(B::bool_into_float(self.primitive)))
35    }
36
37    /// Inverses boolean values.
38    pub fn bool_not(self) -> Self {
39        Tensor::new(B::bool_not(self.primitive))
40    }
41
42    /// Performs logical and (`&&`) on two boolean tensors
43    pub fn bool_and(self, rhs: Tensor<B, D, Bool>) -> Tensor<B, D, Bool> {
44        Tensor::new(B::bool_and(self.primitive, rhs.primitive))
45    }
46
47    /// Performs logical or (`||`) on two boolean tensors
48    pub fn bool_or(self, rhs: Tensor<B, D, Bool>) -> Tensor<B, D, Bool> {
49        Tensor::new(B::bool_or(self.primitive, rhs.primitive))
50    }
51
52    /// Performs logical xor (`^`) on two boolean tensors
53    pub fn bool_xor(self, rhs: Tensor<B, D, Bool>) -> Tensor<B, D, Bool> {
54        Tensor::new(B::bool_xor(self.primitive, rhs.primitive))
55    }
56
57    /// Compute the indices of `true` elements in the tensor (i.e., non-zero for boolean tensors).
58    ///
59    /// # Returns
60    ///
61    /// A vector of tensors, one for each dimension of the given tensor, containing the indices of
62    /// the non-zero elements in that dimension.
63    pub fn nonzero(self) -> Vec<Tensor<B, 1, Int>> {
64        try_read_sync(self.nonzero_async())
65            .expect("Failed to read tensor data synchronously. Try using nonzero_async instead.")
66    }
67
68    /// Compute the indices of `true` elements in the tensor (i.e., non-zero for boolean tensors).
69    ///
70    /// # Returns
71    ///
72    /// A vector of tensors, one for each dimension of the given tensor, containing the indices of
73    /// the non-zero elements in that dimension.
74    pub async fn nonzero_async(self) -> Vec<Tensor<B, 1, Int>> {
75        let indices = self.argwhere_async().await;
76
77        if indices.shape().num_elements() == 0 {
78            // Return empty vec when all elements are zero
79            return vec![];
80        }
81
82        let dims = indices.shape().dims;
83        indices
84            .chunk(dims[1], 1)
85            .into_iter()
86            .map(|t| t.reshape(Shape::new([dims[0]])))
87            .collect()
88    }
89
90    /// Compute the indices of the elements that are true, grouped by element.
91    ///
92    /// # Returns
93    ///
94    /// A tensor containing the indices of all non-zero elements of the given tensor. Each row in the
95    /// result contains the indices of a non-zero element.
96    pub fn argwhere(self) -> Tensor<B, 2, Int> {
97        try_read_sync(self.argwhere_async())
98            .expect("Failed to read tensor data synchronously. Try using argwhere_async instead.")
99    }
100
101    /// Compute the indices of the elements that are true, grouped by element.
102    ///
103    /// # Returns
104    ///
105    /// A tensor containing the indices of all non-zero elements of the given tensor. Each row in the
106    /// result contains the indices of a non-zero element.
107    pub async fn argwhere_async(self) -> Tensor<B, 2, Int> {
108        Tensor::new(B::bool_argwhere(self.primitive).await)
109    }
110
111    /// Creates a mask for the upper, lower triangle, or diagonal of a matrix, which can be used to
112    /// fill the specified area with a value.
113    fn tri_mask<S: Into<Shape>>(
114        shape: S,
115        tri_part: TriPart,
116        offset: i64,
117        device: &B::Device,
118    ) -> Self {
119        let shape: Shape = shape.into();
120        let height = shape[D - 2];
121        let width = shape[D - 1];
122
123        // Generate row and column index tensors.
124        let row_indices: Tensor<B, 1, Int> = Tensor::arange(0..height as i64, device);
125        let col_indices: Tensor<B, 1, Int> = Tensor::arange(0..width as i64, device);
126
127        // Prepare shapes for broadcasting.
128        let mut row_shape = [1; D];
129        row_shape[D - 2] = height;
130        let mut col_shape = [1; D];
131        col_shape[D - 1] = width;
132
133        // Reshape for broadcasting.
134        let row_broadcast: Tensor<B, D, Int> = row_indices.reshape(Shape::new(row_shape));
135        let col_broadcast = col_indices.reshape(Shape::new(col_shape));
136
137        // Broadcasting trick to create a matrix that facilitates comparison for mask generation.
138        let matrix = row_broadcast.clone() - (col_broadcast.clone() - offset);
139
140        // Select the appropriate comparison function based on `tri_part`.
141        let compare = match tri_part {
142            TriPart::Upper => Tensor::greater_elem,
143            TriPart::Lower => Tensor::lower_elem,
144            TriPart::Diagonal => Tensor::not_equal_elem,
145        };
146
147        // Generate and return the mask by applying the comparison to the matrix.
148        compare(matrix, 0).unsqueeze()
149    }
150
151    /// Creates a mask for the upper triangle of a matrix, which can be used to fill the specified
152    /// area with a value.
153    ///
154    /// This function generates a boolean tensor representing the mask of the upper triangle of a matrix.
155    ///
156    /// # Arguments
157    ///
158    /// * `shape`: The shape of the matrix.
159    /// * `offset`: The offset from the diagonal, where 0 means the diagonal, and positive values shift
160    ///   towards the upper triangle.
161    /// * `device`: The device on which the tensor will be allocated.
162    ///
163    /// # Returns
164    ///
165    /// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the
166    /// upper triangle taking into account the specified `offset`. All other elements are `true`.
167    ///
168    /// # Example
169    /// ```rust
170    /// use burn_tensor::backend::Backend;
171    /// use burn_tensor::{Tensor, Bool};
172    ///
173    /// fn example<B: Backend>() {
174    ///   let mask = Tensor::<B, 2, Bool>::triu_mask([3, 3], 0, &Default::default());
175    ///   println!("{mask}");
176    ///   // [[false, false, false],
177    ///   //  [true, false, false],
178    ///   //  [true, true, false]]
179    /// }
180    /// ```
181    pub fn triu_mask<S: Into<Shape>>(shape: S, offset: i64, device: &B::Device) -> Self {
182        Self::tri_mask(shape, TriPart::Upper, offset, device)
183    }
184
185    /// Creates a mask for the lower triangle of a matrix, which can be used to fill the specified
186    /// area with a value.
187    ///
188    /// This function generates a boolean tensor representing the mask of the lower triangle of a matrix.
189    ///
190    /// # Arguments
191    ///
192    /// * `shape`: The shape of the matrix.
193    /// * `offset`: The offset from the diagonal, where 0 means the diagonal, and negative values shift
194    ///   towards the lower triangle.
195    /// * `device`: The device on which the tensor will be allocated.
196    ///
197    /// # Returns
198    ///
199    /// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the
200    /// lower triangle taking into account the specified `offset`. All other elements are `true`.
201    ///
202    /// # Example
203    /// ```rust
204    /// use burn_tensor::backend::Backend;
205    /// use burn_tensor::{Tensor, Bool};
206    ///
207    /// fn example<B: Backend>() {
208    ///   let mask = Tensor::<B, 2, Bool>::tril_mask([3, 3], 0, &Default::default());
209    ///   println!("{mask}");
210    ///   // [[false, true, true],
211    ///   //  [false, false, true],
212    ///   //  [false, false, false]]
213    /// }
214    /// ```
215    pub fn tril_mask<S: Into<Shape>>(shape: S, offset: i64, device: &B::Device) -> Self {
216        Self::tri_mask(shape, TriPart::Lower, offset, device)
217    }
218
219    /// Creates a mask for the diagonal of a matrix, which can be used to fill the specified
220    /// area with a value.
221    ///
222    /// This function generates a boolean tensor representing the mask of the diagonal of a matrix.
223    ///
224    /// # Arguments
225    ///
226    /// * `shape`: The shape of the matrix.
227    /// * `offset`: The offset from the diagonal, where 0 means the diagonal, and positive values shift
228    ///   towards the upper triangle.
229    /// * `device`: The device on which the tensor will be allocated.
230    ///
231    /// # Returns
232    ///
233    /// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the
234    /// diagonal. All other elements are `true`.
235    ///
236    /// # Example
237    /// ```rust
238    /// use burn_tensor::backend::Backend;
239    /// use burn_tensor::{Tensor, Bool};
240    ///
241    /// fn example<B: Backend>() {
242    ///   let mask = Tensor::<B, 2, Bool>::diag_mask([3, 3], 0, &Default::default());
243    ///   println!("{mask}");
244    ///   // [[false, true, true],
245    ///   //  [true, false, true],
246    ///   //  [true, true, false]]
247    /// }
248    /// ```
249    pub fn diag_mask<S: Into<Shape>>(shape: S, offset: i64, device: &B::Device) -> Self {
250        Self::tri_mask(shape, TriPart::Diagonal, offset, device)
251    }
252}