burn_tensor/tensor/api/bool.rs
1use crate::{Bool, Int, Shape, Tensor, TensorData, TensorPrimitive, backend::Backend};
2use alloc::vec::Vec;
3
4use crate::try_read_sync;
5
6/// The part of the tensor to keep when creating a triangular mask.
/// Which region of a matrix is selected by a mask built with `tri_mask`.
///
/// Elements of the selected region are marked `false` in the generated mask; all other elements
/// are `true`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TriPart {
    /// Upper triangular part.
    Upper,

    /// Lower triangular part.
    Lower,

    /// Diagonal part.
    Diagonal,
}
17
18impl<B, const D: usize> Tensor<B, D, Bool>
19where
20 B: Backend,
21{
22 /// Create a boolean tensor from data on the given device.
23 pub fn from_bool(data: TensorData, device: &B::Device) -> Self {
24 Self::new(B::bool_from_data(data.convert::<B::BoolElem>(), device))
25 }
26
27 /// Convert the bool tensor into an int tensor.
28 pub fn int(self) -> Tensor<B, D, Int> {
29 Tensor::new(B::bool_into_int(self.primitive))
30 }
31
32 /// Convert the bool tensor into an float tensor.
33 pub fn float(self) -> Tensor<B, D> {
34 Tensor::new(TensorPrimitive::Float(B::bool_into_float(self.primitive)))
35 }
36
37 /// Inverses boolean values.
38 pub fn bool_not(self) -> Self {
39 Tensor::new(B::bool_not(self.primitive))
40 }
41
42 /// Performs logical and (`&&`) on two boolean tensors
43 pub fn bool_and(self, rhs: Tensor<B, D, Bool>) -> Tensor<B, D, Bool> {
44 Tensor::new(B::bool_and(self.primitive, rhs.primitive))
45 }
46
47 /// Performs logical or (`||`) on two boolean tensors
48 pub fn bool_or(self, rhs: Tensor<B, D, Bool>) -> Tensor<B, D, Bool> {
49 Tensor::new(B::bool_or(self.primitive, rhs.primitive))
50 }
51
52 /// Compute the indices of the elements that are non-zero.
53 ///
54 /// # Returns
55 ///
56 /// A vector of tensors, one for each dimension of the given tensor, containing the indices of
57 /// the non-zero elements in that dimension.
58 pub fn nonzero(self) -> Vec<Tensor<B, 1, Int>> {
59 try_read_sync(self.nonzero_async())
60 .expect("Failed to read tensor data synchronously. Try using nonzero_async instead.")
61 }
62
63 /// Compute the indices of the elements that are non-zero.
64 ///
65 /// # Returns
66 ///
67 /// A vector of tensors, one for each dimension of the given tensor, containing the indices of
68 /// the non-zero elements in that dimension.
69 pub async fn nonzero_async(self) -> Vec<Tensor<B, 1, Int>> {
70 B::bool_nonzero(self.primitive)
71 .await
72 .into_iter()
73 .map(Tensor::new)
74 .collect()
75 }
76
77 /// Compute the indices of the elements that are true, grouped by element.
78 ///
79 /// # Returns
80 ///
81 /// A tensor containing the indices of all non-zero elements of the given tensor. Each row in the
82 /// result contains the indices of a non-zero element.
83 pub fn argwhere(self) -> Tensor<B, 2, Int> {
84 try_read_sync(self.argwhere_async())
85 .expect("Failed to read tensor data synchronously. Try using argwhere_async instead.")
86 }
87
88 /// Compute the indices of the elements that are true, grouped by element.
89 ///
90 /// # Returns
91 ///
92 /// A tensor containing the indices of all non-zero elements of the given tensor. Each row in the
93 /// result contains the indices of a non-zero element.
94 pub async fn argwhere_async(self) -> Tensor<B, 2, Int> {
95 Tensor::new(B::bool_argwhere(self.primitive).await)
96 }
97
98 /// Creates a mask for the upper, lower triangle, or diagonal of a matrix, which can be used to
99 /// fill the specified area with a value.
100 fn tri_mask<S: Into<Shape>>(
101 shape: S,
102 tri_part: TriPart,
103 offset: i64,
104 device: &B::Device,
105 ) -> Self {
106 let shape: Shape = shape.into();
107 let height = shape.dims[D - 2];
108 let width = shape.dims[D - 1];
109
110 // Generate row and column index tensors.
111 let row_indices: Tensor<B, 1, Int> = Tensor::arange(0..height as i64, device);
112 let col_indices: Tensor<B, 1, Int> = Tensor::arange(0..width as i64, device);
113
114 // Prepare shapes for broadcasting.
115 let mut row_shape = [1; D];
116 row_shape[D - 2] = height;
117 let mut col_shape = [1; D];
118 col_shape[D - 1] = width;
119
120 // Reshape for broadcasting.
121 let row_broadcast: Tensor<B, D, Int> = row_indices.reshape(Shape::new(row_shape));
122 let col_broadcast = col_indices.reshape(Shape::new(col_shape));
123
124 // Broadcasting trick to create a matrix that facilitates comparison for mask generation.
125 let matrix = row_broadcast.clone() - (col_broadcast.clone() - offset);
126
127 // Select the appropriate comparison function based on `tri_part`.
128 let compare = match tri_part {
129 TriPart::Upper => Tensor::greater_elem,
130 TriPart::Lower => Tensor::lower_elem,
131 TriPart::Diagonal => Tensor::not_equal_elem,
132 };
133
134 // Generate and return the mask by applying the comparison to the matrix.
135 compare(matrix, 0).unsqueeze()
136 }
137
138 /// Creates a mask for the upper triangle of a matrix, which can be used to fill the specified
139 /// area with a value.
140 ///
141 /// This function generates a boolean tensor representing the mask of the upper triangle of a matrix.
142 ///
143 /// # Arguments
144 ///
145 /// * `shape`: The shape of the matrix.
146 /// * `offset`: The offset from the diagonal, where 0 means the diagonal, and positive values shift
147 /// towards the upper triangle.
148 /// * `device`: The device on which the tensor will be allocated.
149 ///
150 /// # Returns
151 ///
152 /// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the
153 /// upper triangle taking into account the specified `offset`. All other elements are `true`.
154 ///
155 /// # Example
156 /// ```rust
157 /// use burn_tensor::backend::Backend;
158 /// use burn_tensor::{Tensor, Bool};
159 ///
160 /// fn example<B: Backend>() {
161 /// let mask = Tensor::<B, 2, Bool>::triu_mask([3, 3], 0, &Default::default());
162 /// println!("{mask}");
163 /// // [[false, false, false],
164 /// // [true, false, false],
165 /// // [true, true, false]]
166 /// }
167 /// ```
168 pub fn triu_mask<S: Into<Shape>>(shape: S, offset: i64, device: &B::Device) -> Self {
169 Self::tri_mask(shape, TriPart::Upper, offset, device)
170 }
171
172 /// Creates a mask for the lower triangle of a matrix, which can be used to fill the specified
173 /// area with a value.
174 ///
175 /// This function generates a boolean tensor representing the mask of the lower triangle of a matrix.
176 ///
177 /// # Arguments
178 ///
179 /// * `shape`: The shape of the matrix.
180 /// * `offset`: The offset from the diagonal, where 0 means the diagonal, and negative values shift
181 /// towards the lower triangle.
182 /// * `device`: The device on which the tensor will be allocated.
183 ///
184 /// # Returns
185 ///
186 /// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the
187 /// lower triangle taking into account the specified `offset`. All other elements are `true`.
188 ///
189 /// # Example
190 /// ```rust
191 /// use burn_tensor::backend::Backend;
192 /// use burn_tensor::{Tensor, Bool};
193 ///
194 /// fn example<B: Backend>() {
195 /// let mask = Tensor::<B, 2, Bool>::tril_mask([3, 3], 0, &Default::default());
196 /// println!("{mask}");
197 /// // [[false, true, true],
198 /// // [false, false, true],
199 /// // [false, false, false]]
200 /// }
201 /// ```
202 pub fn tril_mask<S: Into<Shape>>(shape: S, offset: i64, device: &B::Device) -> Self {
203 Self::tri_mask(shape, TriPart::Lower, offset, device)
204 }
205
206 /// Creates a mask for the diagonal of a matrix, which can be used to fill the specified
207 /// area with a value.
208 ///
209 /// This function generates a boolean tensor representing the mask of the diagonal of a matrix.
210 ///
211 /// # Arguments
212 ///
213 /// * `shape`: The shape of the matrix.
214 /// * `offset`: The offset from the diagonal, where 0 means the diagonal, and positive values shift
215 /// towards the upper triangle.
216 /// * `device`: The device on which the tensor will be allocated.
217 ///
218 /// # Returns
219 ///
220 /// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the
221 /// diagonal. All other elements are `true`.
222 ///
223 /// # Example
224 /// ```rust
225 /// use burn_tensor::backend::Backend;
226 /// use burn_tensor::{Tensor, Bool};
227 ///
228 /// fn example<B: Backend>() {
229 /// let mask = Tensor::<B, 2, Bool>::diag_mask([3, 3], 0, &Default::default());
230 /// println!("{mask}");
231 /// // [[false, true, true],
232 /// // [true, false, true],
233 /// // [true, true, false]]
234 /// }
235 /// ```
236 pub fn diag_mask<S: Into<Shape>>(shape: S, offset: i64, device: &B::Device) -> Self {
237 Self::tri_mask(shape, TriPart::Diagonal, offset, device)
238 }
239}