burn_tensor/tensor/ops/bool_tensor.rs
1use super::{
2 cat::cat_with_slice_assign, repeat_dim::repeat_with_slice_assign, BoolTensor, Device,
3 FloatTensor, IntTensor,
4};
5use crate::{
6 argwhere_data, backend::Backend, chunk, narrow, split, split_with_sizes, tensor::Shape, Bool,
7 ElementConversion, TensorData, TensorMetadata,
8};
9use alloc::{vec, vec::Vec};
10use core::{future::Future, ops::Range};
11
12/// Bool Tensor API for basic operations, see [tensor](crate::Tensor)
13/// for documentation on each function.
14pub trait BoolTensorOps<B: Backend> {
15 /// Creates a new bool tensor.
16 ///
17 /// # Arguments
18 ///
19 /// * `shape` - The shape of the tensor.
20 /// * `device` - The device to create the tensor on.
21 ///
22 /// # Returns
23 ///
24 /// The boolean tensor with the given shape.
25 fn bool_empty(shape: Shape, device: &Device<B>) -> BoolTensor<B>;
26
27 /// Converts the tensor to a data structure.
28 ///
29 /// # Arguments
30 ///
31 /// * `tensor` - The tensor.
32 ///
33 /// # Returns
34 ///
35 /// The data structure with the tensor's data.
36 fn bool_into_data(tensor: BoolTensor<B>) -> impl Future<Output = TensorData> + 'static + Send;
37
38 /// Creates a tensor from the data structure.
39 ///
40 /// # Arguments
41 ///
42 /// * `data` - The data structure.
43 /// * `device` - The device to create the tensor on.
44 ///
45 /// # Returns
46 ///
47 /// The tensor with the data.
48 fn bool_from_data(data: TensorData, device: &Device<B>) -> BoolTensor<B>;
49
50 /// Converts bool tensor to int tensor.
51 ///
52 /// # Arguments
53 ///
54 /// * `tensor` - The tensor.
55 ///
56 /// # Returns
57 ///
58 /// The int tensor with the same data as the bool tensor.
59 fn bool_into_int(tensor: BoolTensor<B>) -> IntTensor<B>;
60
61 /// Converts bool tensor to float tensor.
62 ///
63 /// # Arguments
64 ///
65 /// * `tensor` - The tensor.
66 ///
67 /// # Returns
68 ///
69 /// The float tensor with the same data as the bool tensor.
70 fn bool_into_float(tensor: BoolTensor<B>) -> FloatTensor<B>;
71
72 /// Gets the device of the tensor.
73 ///
74 /// # Arguments
75 ///
76 /// * `tensor` - The tensor.
77 ///
78 /// # Returns
79 ///
80 /// The device of the tensor.
81 fn bool_device(tensor: &BoolTensor<B>) -> Device<B>;
82
83 /// Moves the tensor to the device.
84 fn bool_to_device(tensor: BoolTensor<B>, device: &Device<B>) -> BoolTensor<B>;
85
86 /// Reshapes the tensor.
87 ///
88 /// # Arguments
89 ///
90 /// * `tensor` - The tensor.
91 /// * `shape` - The new shape.
92 ///
93 /// # Returns
94 ///
95 /// The tensor with the new shape.
96 fn bool_reshape(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B>;
97
98 /// Gets the values from the tensor for the given ranges.
99 ///
100 /// # Arguments
101 ///
102 /// * `tensor` - The tensor.
103 /// * `ranges` - The ranges to get the values from.
104 ///
105 /// # Returns
106 ///
107 /// The tensor with the values for the given ranges.
108 fn bool_slice(tensor: BoolTensor<B>, ranges: &[Range<usize>]) -> BoolTensor<B>;
109
110 /// Sets the values in the tensor for the given ranges.
111 ///
112 /// # Arguments
113 ///
114 /// * `tensor` - The tensor.
115 /// * `ranges` - The ranges to set the values for.
116 /// * `value` - The values to set.
117 ///
118 /// # Returns
119 ///
120 /// The tensor with the values set for the given ranges.
121 fn bool_slice_assign(
122 tensor: BoolTensor<B>,
123 ranges: &[Range<usize>],
124 value: BoolTensor<B>,
125 ) -> BoolTensor<B>;
126
127 /// Repeats one dimension of the tensor a given number of times along that dimension.
128 ///
129 /// # Arguments
130 ///
131 /// * `tensor` - The tensor.
132 /// * `dim` - The dimension to repeat.
133 /// * `times` - The number of times to repeat the dimension.
134 ///
135 /// # Returns
136 ///
137 /// The tensor with the dimension repeated.
138 fn bool_repeat_dim(tensor: BoolTensor<B>, dim: usize, times: usize) -> BoolTensor<B> {
139 repeat_with_slice_assign::<B, Bool>(tensor, dim, times)
140 }
141
142 /// Concatenates the tensors along the given dimension.
143 ///
144 /// # Arguments
145 ///
146 /// * `tensors` - The tensors to concatenate.
147 /// * `dim` - The dimension to concatenate along.
148 ///
149 /// # Returns
150 ///
151 /// The tensor with the tensors concatenated along the given dimension.
152 fn bool_cat(tensors: Vec<BoolTensor<B>>, dim: usize) -> BoolTensor<B> {
153 cat_with_slice_assign::<B, Bool>(tensors, dim)
154 }
155
156 /// Equates the two tensors.
157 ///
158 /// # Arguments
159 ///
160 /// * `lhs` - The left hand side tensor.
161 /// * `rhs` - The right hand side tensor.
162 ///
163 /// # Returns
164 ///
165 /// The tensor with the result of the equate.
166 fn bool_equal(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;
167
168 /// Element-wise non-equality comparison.
169 ///
170 /// # Arguments
171 ///
172 /// * `lhs` - The left hand side tensor.
173 /// * `rhs` - The right hand side tensor.
174 ///
175 /// # Returns
176 ///
177 /// The tensor with the result of the comparison.
178 fn bool_not_equal(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
179 let equal_tensor = B::bool_equal(lhs, rhs);
180 B::bool_not(equal_tensor)
181 }
182
183 /// Inverses boolean values.
184 ///
185 /// # Arguments
186 ///
187 /// * `tensor` - The tensor.
188 ///
189 /// # Returns
190 ///
191 /// The tensor with the result of the negation.
192 fn bool_not(tensor: BoolTensor<B>) -> BoolTensor<B>;
193
194 /// Transposes a bool tensor.
195 ///
196 /// # Arguments
197 ///
198 /// * `tensor` - The tensor to transpose.
199 ///
200 /// # Returns
201 ///
202 /// The transposed tensor.
203 fn bool_transpose(tensor: BoolTensor<B>) -> BoolTensor<B> {
204 let ndims = tensor.shape().num_dims();
205 Self::bool_swap_dims(tensor, ndims - 2, ndims - 1)
206 }
207
208 /// Swaps two dimensions of a bool tensor.
209 ///
210 /// # Arguments
211 ///
212 /// * `tensor` - The tensor to swap the dimensions of.
213 /// * `dim1` - The first dimension to swap.
214 /// * `dim2` - The second dimension to swap.
215 ///
216 /// # Returns
217 ///
218 /// The tensor with the dimensions swapped.
219 fn bool_swap_dims(tensor: BoolTensor<B>, dim1: usize, dim2: usize) -> BoolTensor<B>;
220
221 /// Permutes the dimensions of a tensor.
222 ///
223 /// # Arguments
224 ///
225 /// * `tensor` - The tensor to permute the dimensions of.
226 /// * `axes` - The new order of the dimensions.
227 /// # Returns
228 ///
229 /// The tensor with the dimensions permuted.
230 fn bool_permute(tensor: BoolTensor<B>, axes: &[usize]) -> BoolTensor<B>;
231
232 /// Reverse the order of elements in a tensor along the given axes.
233 ///
234 /// # Arguments
235 ///
236 /// * `tensor` - The tensor to reverse.
237 /// * `axes` - The axes to reverse.
238 ///
239 /// The tensor with the elements reversed.
240 fn bool_flip(tensor: BoolTensor<B>, axes: &[usize]) -> BoolTensor<B>;
241
242 /// Returns a new tensor with the given dimension narrowed to the given range.
243 ///
244 /// # Arguments
245 ///
246 /// * `dim` - The dimension along which the tensor will be narrowed.
247 /// * `start` - The starting point of the given range.
248 /// * `length` - The ending point of the given range.
249 /// # Panics
250 ///
251 /// - If the dimension is greater than the number of dimensions of the tensor.
252 /// - If the given range exceeds the number of elements on the given dimension.
253 ///
254 /// # Returns
255 ///
256 /// A new tensor with the given dimension narrowed to the given range.
257 fn bool_narrow(
258 tensor: BoolTensor<B>,
259 dim: usize,
260 start: usize,
261 length: usize,
262 ) -> BoolTensor<B> {
263 narrow::<B, Bool>(tensor, dim, start, length)
264 }
265
266 /// Split the tensor along the given dimension into chunks.
267 ///
268 /// # Arguments
269 ///
270 /// * `tensor` - The tensor.
271 /// * `chunks` - The number of chunks to be produced.
272 /// * `times` - The dimension along which the tensor will be split.
273 ///
274 /// # Returns
275 ///
276 /// A vector of tensors.
277 fn bool_chunk(tensor: BoolTensor<B>, chunks: usize, dim: usize) -> Vec<BoolTensor<B>> {
278 chunk::<B, Bool>(tensor, chunks, dim)
279 }
280
281 /// Split the tensor along the given dimension into chunks of `split_size`.
282 ///
283 /// # Arguments
284 ///
285 /// * `tensor` - The tensor.
286 /// * `split_size` - The size of a single chunk.
287 /// * `times` - The dimension along which the tensor will be split.
288 ///
289 /// # Returns
290 ///
291 /// A vector of tensors.
292 fn bool_split(tensor: BoolTensor<B>, split_size: usize, dim: usize) -> Vec<BoolTensor<B>> {
293 split::<B, Bool>(tensor, split_size, dim)
294 }
295
296 /// Split the tensor along the given dimension into chunks with sizes in
297 /// `dim` according to `split_sizes`.
298 ///
299 /// # Arguments
300 ///
301 /// * `tensor` - The tensor.
302 /// * `split_sizes` - Vector of sizes for each chunk.
303 /// * `times` - The dimension along which the tensor will be split.
304 ///
305 /// # Returns
306 ///
307 /// A vector of tensors.
308 fn bool_split_with_sizes(
309 tensor: BoolTensor<B>,
310 split_sizes: Vec<usize>,
311 dim: usize,
312 ) -> Vec<BoolTensor<B>> {
313 split_with_sizes::<B, Bool>(tensor, split_sizes, dim)
314 }
315
316 /// Tests if any element in the boolean `tensor` evaluates to True.
317 ///
318 /// # Arguments
319 ///
320 /// * `tensor` - The tensor to test.
321 ///
322 /// # Returns
323 ///
324 /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
325 fn bool_any(tensor: BoolTensor<B>) -> BoolTensor<B> {
326 let sum = B::int_sum(B::bool_into_int(tensor));
327 B::int_greater_elem(sum, 0.elem())
328 }
329
330 /// Tests if any element in the boolean `tensor` evaluates to True along a given dimension `dim`.
331 ///
332 /// # Arguments
333 ///
334 /// * `tensor` - The tensor to test.
335 /// * `dim` - The axis along which to test.
336 ///
337 /// # Returns
338 ///
339 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
340 /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input
341 /// evaluates to True, False otherwise.
342 fn bool_any_dim(tensor: BoolTensor<B>, dim: usize) -> BoolTensor<B> {
343 let sum = B::int_sum_dim(B::bool_into_int(tensor), dim);
344 B::int_greater_elem(sum, 0.elem())
345 }
346
347 /// Tests if all elements in the boolean `tensor` evaluate to True.
348 ///
349 /// # Arguments
350 ///
351 /// * `tensor` - The tensor to test.
352 ///
353 /// # Returns
354 ///
355 /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
356 /// evaluate to True, False otherwise.
357 fn bool_all(tensor: BoolTensor<B>) -> BoolTensor<B> {
358 let num_elems = tensor.shape().num_elements();
359 let sum = B::int_sum(B::bool_into_int(tensor));
360 B::int_equal_elem(sum, (num_elems as i32).elem())
361 }
362
363 /// Tests if all elements in the boolean `tensor` evaluate to True along a given dimension `dim`.
364 ///
365 /// # Arguments
366 ///
367 /// * `tensor` - The tensor to test.
368 /// * `dim` - The axis along which to test.
369 ///
370 /// # Returns
371 ///
372 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
373 /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
374 /// evaluates to True, False otherwise.
375 fn bool_all_dim(tensor: BoolTensor<B>, dim: usize) -> BoolTensor<B> {
376 let num_elems = tensor.shape().dims[dim];
377 let sum = B::int_sum_dim(B::bool_into_int(tensor), dim);
378 B::int_equal_elem(sum, (num_elems as i32).elem())
379 }
380
381 /// Compute the indices of the elements that are non-zero, grouped by element.
382 ///
383 /// # Arguments
384 ///
385 /// * `tensor` - The input tensor.
386 ///
387 /// # Returns
388 ///
389 /// A 2D tensor containing the indices of all non-zero elements of the given tensor.
390 /// Each row contains the indices of a non-zero element.
391 fn bool_argwhere(tensor: BoolTensor<B>) -> impl Future<Output = IntTensor<B>> + 'static + Send {
392 async {
393 // Size of each output tensor is variable (= number of nonzero elements in the tensor).
394 // Reading the data to count the number of truth values might cause sync but is required.
395 let device = B::bool_device(&tensor);
396 let data = B::bool_into_data(tensor).await;
397 argwhere_data::<B>(data, &device)
398 }
399 }
400
401 /// Compute the indices of the elements that are non-zero.
402 ///
403 /// # Arguments
404 ///
405 /// * `tensor` - The input tensor.
406 ///
407 /// # Returns
408 ///
409 /// A vector of tensors, one for each dimension of the given tensor, containing the indices of
410 /// the non-zero elements in that dimension. If all elements are zero, the vector is empty.
411 fn bool_nonzero(
412 tensor: BoolTensor<B>,
413 ) -> impl Future<Output = Vec<IntTensor<B>>> + 'static + Send {
414 async {
415 let indices = B::bool_argwhere(tensor).await;
416
417 if indices.shape().num_elements() == 0 {
418 // Return empty vec when all elements are zero
419 return vec![];
420 }
421
422 let dims = indices.shape().dims;
423 B::int_chunk(indices, dims[1], 1)
424 .into_iter()
425 .map(|t| B::int_reshape(t, Shape::new([dims[0]])))
426 .collect()
427 }
428 }
429
430 /// Broadcasts the bool `tensor` to the given `shape`.
431 fn bool_expand(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B>;
432}