burn_tensor/tensor/api/bool.rs
1use crate::{backend::Backend, Bool, Int, Shape, Tensor, TensorData, TensorPrimitive};
2use alloc::vec::Vec;
3
4use crate::try_read_sync;
5
6/// The part of the tensor to keep when creating a triangular mask.
7enum TriPart {
8 /// Upper triangular part.
9 Upper,
10
11 /// Lower triangular part.
12 Lower,
13
14 /// Diagonal part.
15 Diagonal,
16}
17
18impl<B, const D: usize> Tensor<B, D, Bool>
19where
20 B: Backend,
21{
22 /// Create a boolean tensor from data on the given device.
23 pub fn from_bool(data: TensorData, device: &B::Device) -> Self {
24 Self::new(B::bool_from_data(data, device))
25 }
26
27 /// Convert the bool tensor into an int tensor.
28 pub fn int(self) -> Tensor<B, D, Int> {
29 Tensor::new(B::bool_into_int(self.primitive))
30 }
31
32 /// Convert the bool tensor into an float tensor.
33 pub fn float(self) -> Tensor<B, D> {
34 Tensor::new(TensorPrimitive::Float(B::bool_into_float(self.primitive)))
35 }
36
37 /// Inverses boolean values.
38 pub fn bool_not(self) -> Self {
39 Tensor::new(B::bool_not(self.primitive))
40 }
41
42 /// Compute the indices of the elements that are non-zero.
43 ///
44 /// # Returns
45 ///
46 /// A vector of tensors, one for each dimension of the given tensor, containing the indices of
47 /// the non-zero elements in that dimension.
48 pub fn nonzero(self) -> Vec<Tensor<B, 1, Int>> {
49 try_read_sync(self.nonzero_async())
50 .expect("Failed to read tensor data synchronously. Try using nonzero_async instead.")
51 }
52
53 /// Compute the indices of the elements that are non-zero.
54 ///
55 /// # Returns
56 ///
57 /// A vector of tensors, one for each dimension of the given tensor, containing the indices of
58 /// the non-zero elements in that dimension.
59 pub async fn nonzero_async(self) -> Vec<Tensor<B, 1, Int>> {
60 B::bool_nonzero(self.primitive)
61 .await
62 .into_iter()
63 .map(Tensor::new)
64 .collect()
65 }
66
67 /// Compute the indices of the elements that are true, grouped by element.
68 ///
69 /// # Returns
70 ///
71 /// A tensor containing the indices of all non-zero elements of the given tensor. Each row in the
72 /// result contains the indices of a non-zero element.
73 pub fn argwhere(self) -> Tensor<B, 2, Int> {
74 try_read_sync(self.argwhere_async())
75 .expect("Failed to read tensor data synchronously. Try using argwhere_async instead.")
76 }
77
78 /// Compute the indices of the elements that are true, grouped by element.
79 ///
80 /// # Returns
81 ///
82 /// A tensor containing the indices of all non-zero elements of the given tensor. Each row in the
83 /// result contains the indices of a non-zero element.
84 pub async fn argwhere_async(self) -> Tensor<B, 2, Int> {
85 Tensor::new(B::bool_argwhere(self.primitive).await)
86 }
87
88 /// Creates a mask for the upper, lower triangle, or diagonal of a matrix, which can be used to
89 /// fill the specified area with a value.
90 fn tri_mask<S: Into<Shape>>(
91 shape: S,
92 tri_part: TriPart,
93 offset: i64,
94 device: &B::Device,
95 ) -> Self {
96 let shape: Shape = shape.into();
97 let height = shape.dims[D - 2];
98 let width = shape.dims[D - 1];
99
100 // Generate row and column index tensors.
101 let row_indices: Tensor<B, 1, Int> = Tensor::arange(0..height as i64, device);
102 let col_indices: Tensor<B, 1, Int> = Tensor::arange(0..width as i64, device);
103
104 // Prepare shapes for broadcasting.
105 let mut row_shape = [1; D];
106 row_shape[D - 2] = height;
107 let mut col_shape = [1; D];
108 col_shape[D - 1] = width;
109
110 // Reshape for broadcasting.
111 let row_broadcast: Tensor<B, D, Int> = row_indices.reshape(Shape::new(row_shape));
112 let col_broadcast = col_indices.reshape(Shape::new(col_shape));
113
114 // Broadcasting trick to create a matrix that facilitates comparison for mask generation.
115 let matrix = row_broadcast.clone() - (col_broadcast.clone() - offset);
116
117 // Select the appropriate comparison function based on `tri_part`.
118 let compare = match tri_part {
119 TriPart::Upper => Tensor::greater_elem,
120 TriPart::Lower => Tensor::lower_elem,
121 TriPart::Diagonal => Tensor::not_equal_elem,
122 };
123
124 // Generate and return the mask by applying the comparison to the matrix.
125 compare(matrix, 0).unsqueeze()
126 }
127
128 /// Creates a mask for the upper triangle of a matrix, which can be used to fill the specified
129 /// area with a value.
130 ///
131 /// This function generates a boolean tensor representing the mask of the upper triangle of a matrix.
132 ///
133 /// # Arguments
134 ///
135 /// * `shape`: The shape of the matrix.
136 /// * `offset`: The offset from the diagonal, where 0 means the diagonal, and positive values shift
137 /// towards the upper triangle.
138 /// * `device`: The device on which the tensor will be allocated.
139 ///
140 /// # Returns
141 ///
142 /// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the
143 /// upper triangle taking into account the specified `offset`. All other elements are `true`.
144 ///
145 /// # Example
146 /// ```rust
147 /// use burn_tensor::backend::Backend;
148 /// use burn_tensor::{Tensor, Bool};
149 ///
150 /// fn example<B: Backend>() {
151 /// let mask = Tensor::<B, 2, Bool>::triu_mask([3, 3], 0, &Default::default());
152 /// println!("{mask}");
153 /// // [[false, false, false],
154 /// // [true, false, false],
155 /// // [true, true, false]]
156 /// }
157 /// ```
158 pub fn triu_mask<S: Into<Shape>>(shape: S, offset: i64, device: &B::Device) -> Self {
159 Self::tri_mask(shape, TriPart::Upper, offset, device)
160 }
161
162 /// Creates a mask for the lower triangle of a matrix, which can be used to fill the specified
163 /// area with a value.
164 ///
165 /// This function generates a boolean tensor representing the mask of the lower triangle of a matrix.
166 ///
167 /// # Arguments
168 ///
169 /// * `shape`: The shape of the matrix.
170 /// * `offset`: The offset from the diagonal, where 0 means the diagonal, and negative values shift
171 /// towards the lower triangle.
172 /// * `device`: The device on which the tensor will be allocated.
173 ///
174 /// # Returns
175 ///
176 /// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the
177 /// lower triangle taking into account the specified `offset`. All other elements are `true`.
178 ///
179 /// # Example
180 /// ```rust
181 /// use burn_tensor::backend::Backend;
182 /// use burn_tensor::{Tensor, Bool};
183 ///
184 /// fn example<B: Backend>() {
185 /// let mask = Tensor::<B, 2, Bool>::tril_mask([3, 3], 0, &Default::default());
186 /// println!("{mask}");
187 /// // [[false, true, true],
188 /// // [false, false, true],
189 /// // [false, false, false]]
190 /// }
191 /// ```
192 pub fn tril_mask<S: Into<Shape>>(shape: S, offset: i64, device: &B::Device) -> Self {
193 Self::tri_mask(shape, TriPart::Lower, offset, device)
194 }
195
196 /// Creates a mask for the diagonal of a matrix, which can be used to fill the specified
197 /// area with a value.
198 ///
199 /// This function generates a boolean tensor representing the mask of the diagonal of a matrix.
200 ///
201 /// # Arguments
202 ///
203 /// * `shape`: The shape of the matrix.
204 /// * `device`: The device on which the tensor will be allocated.
205 ///
206 /// # Returns
207 ///
208 /// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the
209 /// diagonal. All other elements are `true`.
210 ///
211 /// # Example
212 /// ```rust
213 /// use burn_tensor::backend::Backend;
214 /// use burn_tensor::{Tensor, Bool};
215 ///
216 /// fn example<B: Backend>() {
217 /// let mask = Tensor::<B, 2, Bool>::diag_mask([3, 3], 0, &Default::default());
218 /// println!("{mask}");
219 /// // [[false, true, true],
220 /// // [true, false, true],
221 /// // [true, true, false]]
222 /// }
223 /// ```
224 pub fn diag_mask<S: Into<Shape>>(shape: S, offset: i64, device: &B::Device) -> Self {
225 Self::tri_mask(shape, TriPart::Diagonal, offset, device)
226 }
227}