burn_backend/backend/ops/bool_tensor.rs
1use super::{
2 argwhere::argwhere_data, cat::cat_with_slice_assign, repeat_dim::repeat_with_slice_assign,
3};
4use crate::ExecutionError;
5use crate::tensor::{Bool, BoolElem, BoolTensor, Device, FloatTensor, IntTensor};
6use crate::{Backend, TensorData, TensorMetadata, element::ElementConversion};
7use alloc::vec::Vec;
8use burn_std::{Shape, Slice};
9use core::future::Future;
10
11/// Bool Tensor API for basic operations, see
12#[cfg_attr(doc, doc = crate::doc_tensor!())]
13#[cfg_attr(not(doc), doc = "`Tensor`")]
14/// for documentation on each function.
15pub trait BoolTensorOps<B: Backend> {
16 /// Creates a new bool tensor.
17 ///
18 /// # Arguments
19 ///
20 /// * `shape` - The shape of the tensor.
21 /// * `device` - The device to create the tensor on.
22 ///
23 /// # Returns
24 ///
25 /// The boolean tensor with the given shape.
26 fn bool_empty(shape: Shape, device: &Device<B>) -> BoolTensor<B>;
27
28 /// Creates a new bool tensor filled false.
29 ///
30 /// # Arguments
31 ///
32 /// * `shape` - The shape of the tensor.
33 /// * `device` - The device to create the tensor on.
34 ///
35 /// # Returns
36 ///
37 /// The boolean tensor filled with false.
38 fn bool_zeros(shape: Shape, device: &Device<B>) -> BoolTensor<B>;
39
40 /// Creates a new bool tensor filled true.
41 ///
42 /// # Arguments
43 ///
44 /// * `shape` - The shape of the tensor.
45 /// * `device` - The device to create the tensor on.
46 ///
47 /// # Returns
48 ///
49 /// The boolean tensor filled with true.
50 fn bool_ones(shape: Shape, device: &Device<B>) -> BoolTensor<B>;
51
52 /// Converts the tensor to a data structure.
53 ///
54 /// # Arguments
55 ///
56 /// * `tensor` - The tensor.
57 ///
58 /// # Returns
59 ///
60 /// The data structure with the tensor's data.
61 fn bool_into_data(
62 tensor: BoolTensor<B>,
63 ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;
64
65 /// Creates a tensor from the data structure.
66 ///
67 /// # Arguments
68 ///
69 /// * `data` - The data structure.
70 /// * `device` - The device to create the tensor on.
71 ///
72 /// # Returns
73 ///
74 /// The tensor with the data.
75 fn bool_from_data(data: TensorData, device: &Device<B>) -> BoolTensor<B>;
76
77 /// Converts bool tensor to int tensor.
78 ///
79 /// # Arguments
80 ///
81 /// * `tensor` - The tensor.
82 ///
83 /// # Returns
84 ///
85 /// The int tensor with the same data as the bool tensor.
86 fn bool_into_int(tensor: BoolTensor<B>) -> IntTensor<B>;
87
88 /// Converts bool tensor to float tensor.
89 ///
90 /// # Arguments
91 ///
92 /// * `tensor` - The tensor.
93 ///
94 /// # Returns
95 ///
96 /// The float tensor with the same data as the bool tensor.
97 fn bool_into_float(tensor: BoolTensor<B>) -> FloatTensor<B>;
98
99 /// Gets the device of the tensor.
100 ///
101 /// # Arguments
102 ///
103 /// * `tensor` - The tensor.
104 ///
105 /// # Returns
106 ///
107 /// The device of the tensor.
108 fn bool_device(tensor: &BoolTensor<B>) -> Device<B>;
109
110 /// Moves the tensor to the device.
111 fn bool_to_device(tensor: BoolTensor<B>, device: &Device<B>) -> BoolTensor<B>;
112
113 /// Reshapes the tensor.
114 ///
115 /// # Arguments
116 ///
117 /// * `tensor` - The tensor.
118 /// * `shape` - The new shape.
119 ///
120 /// # Returns
121 ///
122 /// The tensor with the new shape.
123 fn bool_reshape(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B>;
124
125 /// Gets the values from the tensor for the given ranges.
126 ///
127 /// # Arguments
128 ///
129 /// * `tensor` - The tensor.
130 /// * `slices` - The slices specifying ranges and steps for each dimension.
131 ///
132 /// # Returns
133 ///
134 /// The tensor with the values for the given slices.
135 ///
136 /// # Note
137 ///
138 /// Empty slices (where start >= end) are handled at the high-level tensor API and will not
139 /// be passed to this method. Backend implementations do not need to handle empty slices.
140 fn bool_slice(tensor: BoolTensor<B>, slices: &[Slice]) -> BoolTensor<B>;
141
142 /// Sets the values in the tensor for the given ranges.
143 ///
144 /// # Arguments
145 ///
146 /// * `tensor` - The tensor.
147 /// * `ranges` - The ranges to set the values for.
148 /// * `value` - The values to set.
149 ///
150 /// # Returns
151 ///
152 /// The tensor with the values set for the given ranges.
153 ///
154 /// # Note
155 ///
156 /// Empty slice assignments (where any slice range produces 0 elements) are handled at the
157 /// high-level tensor API and will not be passed to this method. Backend implementations do
158 /// not need to handle empty slice assignments.
159 fn bool_slice_assign(
160 tensor: BoolTensor<B>,
161 slices: &[Slice],
162 value: BoolTensor<B>,
163 ) -> BoolTensor<B>;
164
165 /// Fills the tensor with values from the value tensor if the mask is true at the given
166 /// indices.
167 ///
168 /// # Arguments
169 ///
170 /// * `tensor` - The tensor.
171 /// * `mask` - The mask.
172 /// * `value` - The value tensor.
173 ///
174 /// # Returns
175 ///
176 /// The tensor with the values filled.
177 fn bool_mask_where(
178 tensor: BoolTensor<B>,
179 mask: BoolTensor<B>,
180 value: BoolTensor<B>,
181 ) -> BoolTensor<B>;
182
183 /// Fills the tensor with the given value if the mask is true at the given indices.
184 ///
185 /// # Arguments
186 ///
187 /// * `tensor` - The tensor.
188 /// * `mask` - The mask.
189 /// * `value` - The value.
190 ///
191 /// # Returns
192 ///
193 /// The tensor with the values filled.
194 fn bool_mask_fill(
195 tensor: BoolTensor<B>,
196 mask: BoolTensor<B>,
197 value: BoolElem<B>,
198 ) -> BoolTensor<B>;
199
200 /// Gather elements from the tensor at the given indices.
201 ///
202 /// # Arguments
203 ///
204 /// * `dim` - The dimension to gather from.
205 /// * `tensor` - The tensor.
206 /// * `indices` - The indices.
207 fn bool_gather(dim: usize, tensor: BoolTensor<B>, indices: IntTensor<B>) -> BoolTensor<B>;
208
209 /// Scatter a given value to the tensor at the given indices using boolean or reduction.
210 ///
211 /// # Arguments
212 ///
213 /// * `dim` - The dimension to scatter to.
214 /// * `tensor` - The tensor.
215 /// * `indices` - The indices.
216 /// * `value` - The value.
217 ///
218 /// # Returns
219 ///
220 /// The tensor with the values scattered.
221 fn bool_scatter_or(
222 dim: usize,
223 tensor: BoolTensor<B>,
224 indices: IntTensor<B>,
225 value: BoolTensor<B>,
226 ) -> BoolTensor<B>;
227
228 /// Select tensor elements along the given dimension corresponding to the given indices.
229 ///
230 /// # Arguments
231 ///
232 /// * `tensor` - The tensor to select from.
233 /// * `dim` - The dimension to select from.
234 /// * `indices` - The indices of the elements to select.
235 ///
236 /// # Returns
237 ///
238 /// The tensor with the selected elements.
239 fn bool_select(tensor: BoolTensor<B>, dim: usize, indices: IntTensor<B>) -> BoolTensor<B> {
240 // Default implementation: convert to int, select, then convert back to bool
241 let int_tensor = B::bool_into_int(tensor);
242 let selected = B::int_select(int_tensor, dim, indices);
243 B::int_equal_elem(selected, 1_i32.elem())
244 }
245
246 /// Assign the selected elements along the given dimension corresponding to the given indices
247 /// to the given value using sum reduction.
248 ///
249 /// # Arguments
250 ///
251 /// * `tensor` - The tensor to assign the values to.
252 /// * `dim` - The dimension to select from.
253 /// * `indices` - The indices of the elements to assign.
254 /// * `value` - The values to assign.
255 ///
256 /// # Returns
257 ///
258 /// The tensor with the assigned values.
259 fn bool_select_or(
260 tensor: BoolTensor<B>,
261 dim: usize,
262 indices: IntTensor<B>,
263 value: BoolTensor<B>,
264 ) -> BoolTensor<B> {
265 // Default implementation: convert to int, select_assign, then convert back to bool
266 let int_tensor = B::bool_into_int(tensor);
267 let int_values = B::bool_into_int(value);
268 let assigned = B::int_select_add(int_tensor, dim, indices, int_values);
269 // After select_assign with sum reduction, any non-zero value should be true
270 B::int_greater_elem(assigned, 0_i32.elem())
271 }
272
273 /// Repeats one dimension of the tensor a given number of times along that dimension.
274 ///
275 /// # Arguments
276 ///
277 /// * `tensor` - The tensor.
278 /// * `dim` - The dimension to repeat.
279 /// * `times` - The number of times to repeat the dimension.
280 ///
281 /// # Returns
282 ///
283 /// The tensor with the dimension repeated.
284 fn bool_repeat_dim(tensor: BoolTensor<B>, dim: usize, times: usize) -> BoolTensor<B> {
285 repeat_with_slice_assign::<B, Bool>(tensor, dim, times)
286 }
287
288 /// Concatenates the tensors along the given dimension.
289 ///
290 /// # Arguments
291 ///
292 /// * `tensors` - The tensors to concatenate.
293 /// * `dim` - The dimension to concatenate along.
294 ///
295 /// # Returns
296 ///
297 /// The tensor with the tensors concatenated along the given dimension.
298 ///
299 /// # Note
300 ///
301 /// Empty tensors (where the concatenation dimension has size 0) are filtered out at the
302 /// high-level tensor API and will not be passed to this method. Backend implementations do
303 /// not need to handle empty tensors.
304 fn bool_cat(tensors: Vec<BoolTensor<B>>, dim: usize) -> BoolTensor<B> {
305 cat_with_slice_assign::<B, Bool>(tensors, dim)
306 }
307
308 /// Equates the two tensors.
309 ///
310 /// # Arguments
311 ///
312 /// * `lhs` - The left hand side tensor.
313 /// * `rhs` - The right hand side tensor.
314 ///
315 /// # Returns
316 ///
317 /// The tensor with the result of the equate.
318 fn bool_equal(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;
319
320 /// Element-wise non-equality comparison.
321 ///
322 /// # Arguments
323 ///
324 /// * `lhs` - The left hand side tensor.
325 /// * `rhs` - The right hand side tensor.
326 ///
327 /// # Returns
328 ///
329 /// The tensor with the result of the comparison.
330 fn bool_not_equal(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
331 let equal_tensor = B::bool_equal(lhs, rhs);
332 B::bool_not(equal_tensor)
333 }
334
335 /// Element-wise equality comparison with a scalar.
336 ///
337 /// # Arguments
338 ///
339 /// * `lhs` - The left-hand side tensor.
340 /// * `rhs` - The right-hand side scalar.
341 ///
342 /// # Returns
343 ///
344 /// The boolean tensor with the result of the comparison.
345 fn bool_equal_elem(lhs: BoolTensor<B>, rhs: BoolElem<B>) -> BoolTensor<B>;
346
347 /// Element-wise non-equality comparison with a scalar.
348 ///
349 /// # Arguments
350 ///
351 /// * `lhs` - The left-hand side tensor.
352 /// * `rhs` - The right-hand side scalar.
353 ///
354 /// # Returns
355 ///
356 /// The boolean tensor with the result of the comparison.
357 fn bool_not_equal_elem(lhs: BoolTensor<B>, rhs: BoolElem<B>) -> BoolTensor<B> {
358 let equal_tensor = B::bool_equal_elem(lhs, rhs);
359 B::bool_not(equal_tensor)
360 }
361
362 /// Inverses boolean values.
363 ///
364 /// # Arguments
365 ///
366 /// * `tensor` - The tensor.
367 ///
368 /// # Returns
369 ///
370 /// The tensor with the result of the negation.
371 fn bool_not(tensor: BoolTensor<B>) -> BoolTensor<B>;
372
373 /// Executes the logical and (`&&`) operation on two boolean tensors.
374 ///
375 /// # Arguments
376 ///
377 /// * `lhs` - The left hand side tensor.
378 /// * `rhs` - The right hand side tensor.
379 ///
380 /// # Returns
381 ///
382 /// The tensor with the result of the logical and.
383 fn bool_and(tensor: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;
384
385 /// Executes the logical or (`||`) operation on two boolean tensors.
386 ///
387 /// # Arguments
388 ///
389 /// * `lhs` - The left hand side tensor.
390 /// * `rhs` - The right hand side tensor.
391 ///
392 /// # Returns
393 ///
394 /// The tensor with the result of the logical or.
395 fn bool_or(tensor: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;
396
397 /// Element-wise exclusive or.
398 ///
399 /// # Arguments
400 ///
401 /// * `lhs` - The left hand side tensor.
402 /// * `rhs` - The right hand side tensor.
403 ///
404 /// # Returns
405 ///
406 /// The tensor with the result of the comparison.
407 fn bool_xor(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
408 Self::bool_not_equal(lhs, rhs)
409 }
410
411 /// Transposes a bool tensor.
412 ///
413 /// # Arguments
414 ///
415 /// * `tensor` - The tensor to transpose.
416 ///
417 /// # Returns
418 ///
419 /// The transposed tensor.
420 fn bool_transpose(tensor: BoolTensor<B>) -> BoolTensor<B> {
421 let ndims = tensor.shape().num_dims();
422 Self::bool_swap_dims(tensor, ndims - 2, ndims - 1)
423 }
424
425 /// Swaps two dimensions of a bool tensor.
426 ///
427 /// # Arguments
428 ///
429 /// * `tensor` - The tensor to swap the dimensions of.
430 /// * `dim1` - The first dimension to swap.
431 /// * `dim2` - The second dimension to swap.
432 ///
433 /// # Returns
434 ///
435 /// The tensor with the dimensions swapped.
436 fn bool_swap_dims(tensor: BoolTensor<B>, dim1: usize, dim2: usize) -> BoolTensor<B>;
437
438 /// Permutes the dimensions of a tensor.
439 ///
440 /// # Arguments
441 ///
442 /// * `tensor` - The tensor to permute the dimensions of.
443 /// * `axes` - The new order of the dimensions.
444 /// # Returns
445 ///
446 /// The tensor with the dimensions permuted.
447 fn bool_permute(tensor: BoolTensor<B>, axes: &[usize]) -> BoolTensor<B>;
448
449 /// Reverse the order of elements in a tensor along the given axes.
450 ///
451 /// # Arguments
452 ///
453 /// * `tensor` - The tensor to reverse.
454 /// * `axes` - The axes to reverse.
455 ///
456 /// The tensor with the elements reversed.
457 fn bool_flip(tensor: BoolTensor<B>, axes: &[usize]) -> BoolTensor<B>;
458
459 /// Tests if any element in the boolean `tensor` evaluates to True.
460 ///
461 /// # Arguments
462 ///
463 /// * `tensor` - The tensor to test.
464 ///
465 /// # Returns
466 ///
467 /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
468 fn bool_any(tensor: BoolTensor<B>) -> BoolTensor<B> {
469 let sum = B::int_sum(B::bool_into_int(tensor));
470 B::int_greater_elem(sum, 0.elem())
471 }
472
473 /// Tests if any element in the boolean `tensor` evaluates to True along a given dimension `dim`.
474 ///
475 /// # Arguments
476 ///
477 /// * `tensor` - The tensor to test.
478 /// * `dim` - The axis along which to test.
479 ///
480 /// # Returns
481 ///
482 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
483 /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input
484 /// evaluates to True, False otherwise.
485 fn bool_any_dim(tensor: BoolTensor<B>, dim: usize) -> BoolTensor<B> {
486 let sum = B::int_sum_dim(B::bool_into_int(tensor), dim);
487 B::int_greater_elem(sum, 0.elem())
488 }
489
490 /// Tests if all elements in the boolean `tensor` evaluate to True.
491 ///
492 /// # Arguments
493 ///
494 /// * `tensor` - The tensor to test.
495 ///
496 /// # Returns
497 ///
498 /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
499 /// evaluate to True, False otherwise.
500 fn bool_all(tensor: BoolTensor<B>) -> BoolTensor<B> {
501 let num_elems = tensor.shape().num_elements();
502 let sum = B::int_sum(B::bool_into_int(tensor));
503 B::int_equal_elem(sum, (num_elems as i32).elem())
504 }
505
506 /// Tests if all elements in the boolean `tensor` evaluate to True along a given dimension `dim`.
507 ///
508 /// # Arguments
509 ///
510 /// * `tensor` - The tensor to test.
511 /// * `dim` - The axis along which to test.
512 ///
513 /// # Returns
514 ///
515 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
516 /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
517 /// evaluates to True, False otherwise.
518 fn bool_all_dim(tensor: BoolTensor<B>, dim: usize) -> BoolTensor<B> {
519 let num_elems = tensor.shape().dims[dim];
520 let sum = B::int_sum_dim(B::bool_into_int(tensor), dim);
521 B::int_equal_elem(sum, (num_elems as i32).elem())
522 }
523
524 /// Compute the indices of the elements that are non-zero, grouped by element.
525 ///
526 /// # Arguments
527 ///
528 /// * `tensor` - The input tensor.
529 ///
530 /// # Returns
531 ///
532 /// A 2D tensor containing the indices of all non-zero elements of the given tensor.
533 /// Each row contains the indices of a non-zero element.
534 fn bool_argwhere(tensor: BoolTensor<B>) -> impl Future<Output = IntTensor<B>> + 'static + Send {
535 async {
536 // Size of each output tensor is variable (= number of nonzero elements in the tensor).
537 // Reading the data to count the number of truth values might cause sync but is required.
538 let device = B::bool_device(&tensor);
539 let data = B::bool_into_data(tensor)
540 .await
541 .expect("Can read the data without error");
542 argwhere_data::<B>(data, &device)
543 }
544 }
545
546 /// Broadcasts the bool `tensor` to the given `shape`.
547 fn bool_expand(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B>;
548
549 /// Unfold windows along a dimension.
550 ///
551 /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;
552 /// where windows are advanced by `step` at each index.
553 ///
554 /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.
555 ///
556 /// # Arguments
557 ///
558 /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
559 /// * `dim` - the selected dim.
560 /// * `size` - the size of each unfolded window.
561 /// * `step` - the step between each window.
562 ///
563 /// # Returns
564 ///
565 /// A tensor view with shape ``[pre=..., windows, size, post=...]``.
566 fn bool_unfold(tensor: BoolTensor<B>, dim: usize, size: usize, step: usize) -> BoolTensor<B>;
567}