burn_backend/backend/ops/bool_tensor.rs
1use super::{
2 argwhere::argwhere_data, cat::cat_with_slice_assign, repeat_dim::repeat_with_slice_assign,
3};
4use crate::tensor::{Bool, BoolTensor, Device, FloatTensor, IntTensor};
5use crate::{Backend, TensorData, TensorMetadata, get_device_settings};
6use crate::{ExecutionError, Scalar};
7use alloc::vec::Vec;
8use burn_std::{BoolDType, FloatDType, IntDType, Shape, Slice};
9use core::future::Future;
10
11/// Bool Tensor API for basic operations, see
12#[cfg_attr(doc, doc = crate::doc_tensor!())]
13#[cfg_attr(not(doc), doc = "`Tensor`")]
14/// for documentation on each function.
15pub trait BoolTensorOps<B: Backend> {
16 /// Creates a new bool tensor.
17 ///
18 /// # Arguments
19 ///
20 /// * `shape` - The shape of the tensor.
21 /// * `device` - The device to create the tensor on.
22 /// * `dtype` - The target data type.
23 ///
24 /// # Returns
25 ///
26 /// The boolean tensor with the given shape.
27 fn bool_empty(shape: Shape, device: &Device<B>, dtype: BoolDType) -> BoolTensor<B>;
28
29 /// Creates a new bool tensor filled false.
30 ///
31 /// # Arguments
32 ///
33 /// * `shape` - The shape of the tensor.
34 /// * `device` - The device to create the tensor on.
35 /// * `dtype` - The target data type.
36 ///
37 /// # Returns
38 ///
39 /// The boolean tensor filled with false.
40 fn bool_zeros(shape: Shape, device: &Device<B>, dtype: BoolDType) -> BoolTensor<B>;
41
42 /// Creates a new bool tensor filled true.
43 ///
44 /// # Arguments
45 ///
46 /// * `shape` - The shape of the tensor.
47 /// * `device` - The device to create the tensor on.
48 /// * `dtype` - The target data type.
49 ///
50 /// # Returns
51 ///
52 /// The boolean tensor filled with true.
53 fn bool_ones(shape: Shape, device: &Device<B>, dtype: BoolDType) -> BoolTensor<B>;
54
55 /// Converts the tensor to a data structure.
56 ///
57 /// # Arguments
58 ///
59 /// * `tensor` - The tensor.
60 ///
61 /// # Returns
62 ///
63 /// The data structure with the tensor's data.
64 fn bool_into_data(
65 tensor: BoolTensor<B>,
66 ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;
67
68 /// Creates a tensor from the data structure.
69 ///
70 /// # Arguments
71 ///
72 /// * `data` - The data structure.
73 /// * `device` - The device to create the tensor on.
74 ///
75 /// # Returns
76 ///
77 /// The tensor with the data.
78 fn bool_from_data(data: TensorData, device: &Device<B>) -> BoolTensor<B>;
79
80 /// Converts bool tensor to int tensor.
81 ///
82 /// # Arguments
83 ///
84 /// * `tensor` - The tensor.
85 /// * `out_dtype` - The output tensor dtype.
86 ///
87 /// # Returns
88 ///
89 /// The int tensor with the same data as the bool tensor.
90 fn bool_into_int(tensor: BoolTensor<B>, out_dtype: IntDType) -> IntTensor<B>;
91
92 /// Converts bool tensor to float tensor.
93 ///
94 /// # Arguments
95 ///
96 /// * `tensor` - The tensor.
97 /// * `out_dtype` - The output tensor dtype.
98 ///
99 /// # Returns
100 ///
101 /// The float tensor with the same data as the bool tensor.
102 fn bool_into_float(tensor: BoolTensor<B>, out_dtype: FloatDType) -> FloatTensor<B>;
103
104 /// Gets the device of the tensor.
105 ///
106 /// # Arguments
107 ///
108 /// * `tensor` - The tensor.
109 ///
110 /// # Returns
111 ///
112 /// The device of the tensor.
113 fn bool_device(tensor: &BoolTensor<B>) -> Device<B>;
114
115 /// Moves the tensor to the device.
116 fn bool_to_device(tensor: BoolTensor<B>, device: &Device<B>) -> BoolTensor<B>;
117
118 /// Reshapes the tensor.
119 ///
120 /// # Arguments
121 ///
122 /// * `tensor` - The tensor.
123 /// * `shape` - The new shape.
124 ///
125 /// # Returns
126 ///
127 /// The tensor with the new shape.
128 fn bool_reshape(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B>;
129
130 /// Gets the values from the tensor for the given ranges.
131 ///
132 /// # Arguments
133 ///
134 /// * `tensor` - The tensor.
135 /// * `slices` - The slices specifying ranges and steps for each dimension.
136 ///
137 /// # Returns
138 ///
139 /// The tensor with the values for the given slices.
140 ///
141 /// # Note
142 ///
143 /// Empty slices (where start >= end) are handled at the high-level tensor API and will not
144 /// be passed to this method. Backend implementations do not need to handle empty slices.
145 fn bool_slice(tensor: BoolTensor<B>, slices: &[Slice]) -> BoolTensor<B>;
146
147 /// Sets the values in the tensor for the given ranges.
148 ///
149 /// # Arguments
150 ///
151 /// * `tensor` - The tensor.
152 /// * `ranges` - The ranges to set the values for.
153 /// * `value` - The values to set.
154 ///
155 /// # Returns
156 ///
157 /// The tensor with the values set for the given ranges.
158 ///
159 /// # Note
160 ///
161 /// Empty slice assignments (where any slice range produces 0 elements) are handled at the
162 /// high-level tensor API and will not be passed to this method. Backend implementations do
163 /// not need to handle empty slice assignments.
164 fn bool_slice_assign(
165 tensor: BoolTensor<B>,
166 slices: &[Slice],
167 value: BoolTensor<B>,
168 ) -> BoolTensor<B>;
169
170 /// Fills the tensor with values from the value tensor if the mask is true at the given
171 /// indices.
172 ///
173 /// # Arguments
174 ///
175 /// * `tensor` - The tensor.
176 /// * `mask` - The mask.
177 /// * `value` - The value tensor.
178 ///
179 /// # Returns
180 ///
181 /// The tensor with the values filled.
182 fn bool_mask_where(
183 tensor: BoolTensor<B>,
184 mask: BoolTensor<B>,
185 value: BoolTensor<B>,
186 ) -> BoolTensor<B>;
187
188 /// Fills the tensor with the given value if the mask is true at the given indices.
189 ///
190 /// # Arguments
191 ///
192 /// * `tensor` - The tensor.
193 /// * `mask` - The mask.
194 /// * `value` - The value.
195 ///
196 /// # Returns
197 ///
198 /// The tensor with the values filled.
199 fn bool_mask_fill(tensor: BoolTensor<B>, mask: BoolTensor<B>, value: Scalar) -> BoolTensor<B>;
200
201 /// Gather elements from the tensor at the given indices.
202 ///
203 /// # Arguments
204 ///
205 /// * `dim` - The dimension to gather from.
206 /// * `tensor` - The tensor.
207 /// * `indices` - The indices.
208 fn bool_gather(dim: usize, tensor: BoolTensor<B>, indices: IntTensor<B>) -> BoolTensor<B>;
209
210 /// Scatter a given value to the tensor at the given indices using boolean or reduction.
211 ///
212 /// # Arguments
213 ///
214 /// * `dim` - The dimension to scatter to.
215 /// * `tensor` - The tensor.
216 /// * `indices` - The indices.
217 /// * `value` - The value.
218 ///
219 /// # Returns
220 ///
221 /// The tensor with the values scattered.
222 fn bool_scatter_or(
223 dim: usize,
224 tensor: BoolTensor<B>,
225 indices: IntTensor<B>,
226 value: BoolTensor<B>,
227 ) -> BoolTensor<B>;
228
229 /// Select tensor elements along the given dimension corresponding to the given indices.
230 ///
231 /// # Arguments
232 ///
233 /// * `tensor` - The tensor to select from.
234 /// * `dim` - The dimension to select from.
235 /// * `indices` - The indices of the elements to select.
236 ///
237 /// # Returns
238 ///
239 /// The tensor with the selected elements.
240 fn bool_select(tensor: BoolTensor<B>, dim: usize, indices: IntTensor<B>) -> BoolTensor<B>;
241
242 /// Assign the selected elements along the given dimension corresponding to the given indices
243 /// to the given value using sum reduction.
244 ///
245 /// # Arguments
246 ///
247 /// * `tensor` - The tensor to assign the values to.
248 /// * `dim` - The dimension to select from.
249 /// * `indices` - The indices of the elements to assign.
250 /// * `value` - The values to assign.
251 ///
252 /// # Returns
253 ///
254 /// The tensor with the assigned values.
255 fn bool_select_or(
256 tensor: BoolTensor<B>,
257 dim: usize,
258 indices: IntTensor<B>,
259 value: BoolTensor<B>,
260 ) -> BoolTensor<B>;
261
262 /// Repeats one dimension of the tensor a given number of times along that dimension.
263 ///
264 /// # Arguments
265 ///
266 /// * `tensor` - The tensor.
267 /// * `dim` - The dimension to repeat.
268 /// * `times` - The number of times to repeat the dimension.
269 ///
270 /// # Returns
271 ///
272 /// The tensor with the dimension repeated.
273 fn bool_repeat_dim(tensor: BoolTensor<B>, dim: usize, times: usize) -> BoolTensor<B> {
274 repeat_with_slice_assign::<B, Bool>(tensor, dim, times)
275 }
276
277 /// Concatenates the tensors along the given dimension.
278 ///
279 /// # Arguments
280 ///
281 /// * `tensors` - The tensors to concatenate.
282 /// * `dim` - The dimension to concatenate along.
283 ///
284 /// # Returns
285 ///
286 /// The tensor with the tensors concatenated along the given dimension.
287 ///
288 /// # Note
289 ///
290 /// Empty tensors (where the concatenation dimension has size 0) are filtered out at the
291 /// high-level tensor API and will not be passed to this method. Backend implementations do
292 /// not need to handle empty tensors.
293 fn bool_cat(tensors: Vec<BoolTensor<B>>, dim: usize) -> BoolTensor<B> {
294 cat_with_slice_assign::<B, Bool>(tensors, dim)
295 }
296
297 /// Equates the two tensors.
298 ///
299 /// # Arguments
300 ///
301 /// * `lhs` - The left hand side tensor.
302 /// * `rhs` - The right hand side tensor.
303 ///
304 /// # Returns
305 ///
306 /// The tensor with the result of the equate.
307 fn bool_equal(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;
308
309 /// Element-wise non-equality comparison.
310 ///
311 /// # Arguments
312 ///
313 /// * `lhs` - The left hand side tensor.
314 /// * `rhs` - The right hand side tensor.
315 ///
316 /// # Returns
317 ///
318 /// The tensor with the result of the comparison.
319 fn bool_not_equal(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
320 let equal_tensor = B::bool_equal(lhs, rhs);
321 B::bool_not(equal_tensor)
322 }
323
324 /// Element-wise equality comparison with a scalar.
325 ///
326 /// # Arguments
327 ///
328 /// * `lhs` - The left-hand side tensor.
329 /// * `rhs` - The right-hand side scalar.
330 ///
331 /// # Returns
332 ///
333 /// The boolean tensor with the result of the comparison.
334 fn bool_equal_elem(lhs: BoolTensor<B>, rhs: Scalar) -> BoolTensor<B>;
335
336 /// Element-wise non-equality comparison with a scalar.
337 ///
338 /// # Arguments
339 ///
340 /// * `lhs` - The left-hand side tensor.
341 /// * `rhs` - The right-hand side scalar.
342 ///
343 /// # Returns
344 ///
345 /// The boolean tensor with the result of the comparison.
346 fn bool_not_equal_elem(lhs: BoolTensor<B>, rhs: Scalar) -> BoolTensor<B> {
347 let equal_tensor = B::bool_equal_elem(lhs, rhs);
348 B::bool_not(equal_tensor)
349 }
350
351 /// Inverses boolean values.
352 ///
353 /// # Arguments
354 ///
355 /// * `tensor` - The tensor.
356 ///
357 /// # Returns
358 ///
359 /// The tensor with the result of the negation.
360 fn bool_not(tensor: BoolTensor<B>) -> BoolTensor<B>;
361
362 /// Executes the logical and (`&&`) operation on two boolean tensors.
363 ///
364 /// # Arguments
365 ///
366 /// * `lhs` - The left hand side tensor.
367 /// * `rhs` - The right hand side tensor.
368 ///
369 /// # Returns
370 ///
371 /// The tensor with the result of the logical and.
372 fn bool_and(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;
373
374 /// Executes the logical or (`||`) operation on two boolean tensors.
375 ///
376 /// # Arguments
377 ///
378 /// * `lhs` - The left hand side tensor.
379 /// * `rhs` - The right hand side tensor.
380 ///
381 /// # Returns
382 ///
383 /// The tensor with the result of the logical or.
384 fn bool_or(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;
385
386 /// Element-wise exclusive or.
387 ///
388 /// # Arguments
389 ///
390 /// * `lhs` - The left hand side tensor.
391 /// * `rhs` - The right hand side tensor.
392 ///
393 /// # Returns
394 ///
395 /// The tensor with the result of the comparison.
396 fn bool_xor(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
397 Self::bool_not_equal(lhs, rhs)
398 }
399
400 /// Transposes a bool tensor.
401 ///
402 /// # Arguments
403 ///
404 /// * `tensor` - The tensor to transpose.
405 ///
406 /// # Returns
407 ///
408 /// The transposed tensor.
409 fn bool_transpose(tensor: BoolTensor<B>) -> BoolTensor<B> {
410 let ndims = tensor.shape().num_dims();
411 Self::bool_swap_dims(tensor, ndims - 2, ndims - 1)
412 }
413
414 /// Swaps two dimensions of a bool tensor.
415 ///
416 /// # Arguments
417 ///
418 /// * `tensor` - The tensor to swap the dimensions of.
419 /// * `dim1` - The first dimension to swap.
420 /// * `dim2` - The second dimension to swap.
421 ///
422 /// # Returns
423 ///
424 /// The tensor with the dimensions swapped.
425 fn bool_swap_dims(tensor: BoolTensor<B>, dim1: usize, dim2: usize) -> BoolTensor<B>;
426
427 /// Permutes the dimensions of a tensor.
428 ///
429 /// # Arguments
430 ///
431 /// * `tensor` - The tensor to permute the dimensions of.
432 /// * `axes` - The new order of the dimensions.
433 /// # Returns
434 ///
435 /// The tensor with the dimensions permuted.
436 fn bool_permute(tensor: BoolTensor<B>, axes: &[usize]) -> BoolTensor<B>;
437
438 /// Reverse the order of elements in a tensor along the given axes.
439 ///
440 /// # Arguments
441 ///
442 /// * `tensor` - The tensor to reverse.
443 /// * `axes` - The axes to reverse.
444 ///
445 /// The tensor with the elements reversed.
446 fn bool_flip(tensor: BoolTensor<B>, axes: &[usize]) -> BoolTensor<B>;
447
448 /// Tests if any element in the boolean `tensor` evaluates to True.
449 ///
450 /// # Arguments
451 ///
452 /// * `tensor` - The tensor to test.
453 ///
454 /// # Returns
455 ///
456 /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
457 fn bool_any(tensor: BoolTensor<B>) -> BoolTensor<B> {
458 let dtype = tensor.dtype();
459 let int_dtype = get_device_settings::<B>(&B::bool_device(&tensor)).int_dtype;
460 let sum = B::int_sum(B::bool_into_int(tensor, int_dtype));
461 B::int_greater_elem(sum, 0.into(), dtype.into())
462 }
463
464 /// Tests if any element in the boolean `tensor` evaluates to True along a given dimension `dim`.
465 ///
466 /// # Arguments
467 ///
468 /// * `tensor` - The tensor to test.
469 /// * `dim` - The axis along which to test.
470 ///
471 /// # Returns
472 ///
473 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
474 /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input
475 /// evaluates to True, False otherwise.
476 fn bool_any_dim(tensor: BoolTensor<B>, dim: usize) -> BoolTensor<B> {
477 let dtype = tensor.dtype();
478 let int_dtype = get_device_settings::<B>(&B::bool_device(&tensor)).int_dtype;
479 let sum = B::int_sum_dim(B::bool_into_int(tensor, int_dtype), dim);
480 B::int_greater_elem(sum, 0.into(), dtype.into())
481 }
482
483 /// Tests if all elements in the boolean `tensor` evaluate to True.
484 ///
485 /// # Arguments
486 ///
487 /// * `tensor` - The tensor to test.
488 ///
489 /// # Returns
490 ///
491 /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
492 /// evaluate to True, False otherwise.
493 fn bool_all(tensor: BoolTensor<B>) -> BoolTensor<B> {
494 let dtype = tensor.dtype();
495 let int_dtype = get_device_settings::<B>(&B::bool_device(&tensor)).int_dtype;
496 let num_elems = tensor.shape().num_elements() as i64;
497 let sum = B::int_sum(B::bool_into_int(tensor, int_dtype));
498 B::int_equal_elem(sum, num_elems.into(), dtype.into())
499 }
500
501 /// Tests if all elements in the boolean `tensor` evaluate to True along a given dimension `dim`.
502 ///
503 /// # Arguments
504 ///
505 /// * `tensor` - The tensor to test.
506 /// * `dim` - The axis along which to test.
507 ///
508 /// # Returns
509 ///
510 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
511 /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
512 /// evaluates to True, False otherwise.
513 fn bool_all_dim(tensor: BoolTensor<B>, dim: usize) -> BoolTensor<B> {
514 let dtype = tensor.dtype();
515 let int_dtype = get_device_settings::<B>(&B::bool_device(&tensor)).int_dtype;
516 let num_elems = tensor.shape()[dim] as i64;
517 let sum = B::int_sum_dim(B::bool_into_int(tensor, int_dtype), dim);
518 B::int_equal_elem(sum, num_elems.into(), dtype.into())
519 }
520
521 /// Compute the indices of the elements that are non-zero, grouped by element.
522 ///
523 /// # Arguments
524 ///
525 /// * `tensor` - The input tensor.
526 /// * `out_dtype` - The output tensor dtype.
527 ///
528 /// # Returns
529 ///
530 /// A 2D tensor containing the indices of all non-zero elements of the given tensor.
531 /// Each row contains the indices of a non-zero element.
532 fn bool_argwhere(
533 tensor: BoolTensor<B>,
534 out_dtype: IntDType,
535 ) -> impl Future<Output = IntTensor<B>> + 'static + Send {
536 async move {
537 // Size of each output tensor is variable (= number of nonzero elements in the tensor).
538 // Reading the data to count the number of truth values might cause sync but is required.
539 let device = B::bool_device(&tensor);
540 let data = B::bool_into_data(tensor)
541 .await
542 .expect("Can read the data without error");
543 argwhere_data::<B>(data, &device, out_dtype)
544 }
545 }
546
547 /// Broadcasts the bool `tensor` to the given `shape`.
548 fn bool_expand(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B>;
549
550 /// Unfold windows along a dimension.
551 ///
552 /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;
553 /// where windows are advanced by `step` at each index.
554 ///
555 /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.
556 ///
557 /// # Arguments
558 ///
559 /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
560 /// * `dim` - the selected dim.
561 /// * `size` - the size of each unfolded window.
562 /// * `step` - the step between each window.
563 ///
564 /// # Returns
565 ///
566 /// A tensor view with shape ``[pre=..., windows, size, post=...]``.
567 fn bool_unfold(tensor: BoolTensor<B>, dim: usize, size: usize, step: usize) -> BoolTensor<B>;
568}