burn_backend/backend/ops/bool_tensor.rs
1use super::{
2 argwhere::argwhere_data, cat::cat_with_slice_assign, repeat_dim::repeat_with_slice_assign,
3};
4use crate::tensor::{Bool, BoolTensor, Device, FloatTensor, IntTensor};
5use crate::{Backend, TensorData, TensorMetadata};
6use crate::{ExecutionError, Scalar};
7use alloc::vec::Vec;
8use burn_std::{Shape, Slice};
9use core::future::Future;
10
11/// Bool Tensor API for basic operations, see
12#[cfg_attr(doc, doc = crate::doc_tensor!())]
13#[cfg_attr(not(doc), doc = "`Tensor`")]
14/// for documentation on each function.
15pub trait BoolTensorOps<B: Backend> {
16 /// Creates a new bool tensor.
17 ///
18 /// # Arguments
19 ///
20 /// * `shape` - The shape of the tensor.
21 /// * `device` - The device to create the tensor on.
22 ///
23 /// # Returns
24 ///
25 /// The boolean tensor with the given shape.
26 fn bool_empty(shape: Shape, device: &Device<B>) -> BoolTensor<B>;
27
28 /// Creates a new bool tensor filled false.
29 ///
30 /// # Arguments
31 ///
32 /// * `shape` - The shape of the tensor.
33 /// * `device` - The device to create the tensor on.
34 ///
35 /// # Returns
36 ///
37 /// The boolean tensor filled with false.
38 fn bool_zeros(shape: Shape, device: &Device<B>) -> BoolTensor<B>;
39
40 /// Creates a new bool tensor filled true.
41 ///
42 /// # Arguments
43 ///
44 /// * `shape` - The shape of the tensor.
45 /// * `device` - The device to create the tensor on.
46 ///
47 /// # Returns
48 ///
49 /// The boolean tensor filled with true.
50 fn bool_ones(shape: Shape, device: &Device<B>) -> BoolTensor<B>;
51
52 /// Converts the tensor to a data structure.
53 ///
54 /// # Arguments
55 ///
56 /// * `tensor` - The tensor.
57 ///
58 /// # Returns
59 ///
60 /// The data structure with the tensor's data.
61 fn bool_into_data(
62 tensor: BoolTensor<B>,
63 ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;
64
65 /// Creates a tensor from the data structure.
66 ///
67 /// # Arguments
68 ///
69 /// * `data` - The data structure.
70 /// * `device` - The device to create the tensor on.
71 ///
72 /// # Returns
73 ///
74 /// The tensor with the data.
75 fn bool_from_data(data: TensorData, device: &Device<B>) -> BoolTensor<B>;
76
77 /// Converts bool tensor to int tensor.
78 ///
79 /// # Arguments
80 ///
81 /// * `tensor` - The tensor.
82 ///
83 /// # Returns
84 ///
85 /// The int tensor with the same data as the bool tensor.
86 fn bool_into_int(tensor: BoolTensor<B>) -> IntTensor<B>;
87
88 /// Converts bool tensor to float tensor.
89 ///
90 /// # Arguments
91 ///
92 /// * `tensor` - The tensor.
93 ///
94 /// # Returns
95 ///
96 /// The float tensor with the same data as the bool tensor.
97 fn bool_into_float(tensor: BoolTensor<B>) -> FloatTensor<B>;
98
99 /// Gets the device of the tensor.
100 ///
101 /// # Arguments
102 ///
103 /// * `tensor` - The tensor.
104 ///
105 /// # Returns
106 ///
107 /// The device of the tensor.
108 fn bool_device(tensor: &BoolTensor<B>) -> Device<B>;
109
110 /// Moves the tensor to the device.
111 fn bool_to_device(tensor: BoolTensor<B>, device: &Device<B>) -> BoolTensor<B>;
112
113 /// Reshapes the tensor.
114 ///
115 /// # Arguments
116 ///
117 /// * `tensor` - The tensor.
118 /// * `shape` - The new shape.
119 ///
120 /// # Returns
121 ///
122 /// The tensor with the new shape.
123 fn bool_reshape(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B>;
124
125 /// Gets the values from the tensor for the given ranges.
126 ///
127 /// # Arguments
128 ///
129 /// * `tensor` - The tensor.
130 /// * `slices` - The slices specifying ranges and steps for each dimension.
131 ///
132 /// # Returns
133 ///
134 /// The tensor with the values for the given slices.
135 ///
136 /// # Note
137 ///
138 /// Empty slices (where start >= end) are handled at the high-level tensor API and will not
139 /// be passed to this method. Backend implementations do not need to handle empty slices.
140 fn bool_slice(tensor: BoolTensor<B>, slices: &[Slice]) -> BoolTensor<B>;
141
142 /// Sets the values in the tensor for the given ranges.
143 ///
144 /// # Arguments
145 ///
146 /// * `tensor` - The tensor.
147 /// * `ranges` - The ranges to set the values for.
148 /// * `value` - The values to set.
149 ///
150 /// # Returns
151 ///
152 /// The tensor with the values set for the given ranges.
153 ///
154 /// # Note
155 ///
156 /// Empty slice assignments (where any slice range produces 0 elements) are handled at the
157 /// high-level tensor API and will not be passed to this method. Backend implementations do
158 /// not need to handle empty slice assignments.
159 fn bool_slice_assign(
160 tensor: BoolTensor<B>,
161 slices: &[Slice],
162 value: BoolTensor<B>,
163 ) -> BoolTensor<B>;
164
165 /// Fills the tensor with values from the value tensor if the mask is true at the given
166 /// indices.
167 ///
168 /// # Arguments
169 ///
170 /// * `tensor` - The tensor.
171 /// * `mask` - The mask.
172 /// * `value` - The value tensor.
173 ///
174 /// # Returns
175 ///
176 /// The tensor with the values filled.
177 fn bool_mask_where(
178 tensor: BoolTensor<B>,
179 mask: BoolTensor<B>,
180 value: BoolTensor<B>,
181 ) -> BoolTensor<B>;
182
183 /// Fills the tensor with the given value if the mask is true at the given indices.
184 ///
185 /// # Arguments
186 ///
187 /// * `tensor` - The tensor.
188 /// * `mask` - The mask.
189 /// * `value` - The value.
190 ///
191 /// # Returns
192 ///
193 /// The tensor with the values filled.
194 fn bool_mask_fill(tensor: BoolTensor<B>, mask: BoolTensor<B>, value: Scalar) -> BoolTensor<B>;
195
196 /// Gather elements from the tensor at the given indices.
197 ///
198 /// # Arguments
199 ///
200 /// * `dim` - The dimension to gather from.
201 /// * `tensor` - The tensor.
202 /// * `indices` - The indices.
203 fn bool_gather(dim: usize, tensor: BoolTensor<B>, indices: IntTensor<B>) -> BoolTensor<B>;
204
205 /// Scatter a given value to the tensor at the given indices using boolean or reduction.
206 ///
207 /// # Arguments
208 ///
209 /// * `dim` - The dimension to scatter to.
210 /// * `tensor` - The tensor.
211 /// * `indices` - The indices.
212 /// * `value` - The value.
213 ///
214 /// # Returns
215 ///
216 /// The tensor with the values scattered.
217 fn bool_scatter_or(
218 dim: usize,
219 tensor: BoolTensor<B>,
220 indices: IntTensor<B>,
221 value: BoolTensor<B>,
222 ) -> BoolTensor<B>;
223
224 /// Select tensor elements along the given dimension corresponding to the given indices.
225 ///
226 /// # Arguments
227 ///
228 /// * `tensor` - The tensor to select from.
229 /// * `dim` - The dimension to select from.
230 /// * `indices` - The indices of the elements to select.
231 ///
232 /// # Returns
233 ///
234 /// The tensor with the selected elements.
235 fn bool_select(tensor: BoolTensor<B>, dim: usize, indices: IntTensor<B>) -> BoolTensor<B> {
236 // Default implementation: convert to int, select, then convert back to bool
237 let int_tensor = B::bool_into_int(tensor);
238 let selected = B::int_select(int_tensor, dim, indices);
239 B::int_equal_elem(selected, 1.into())
240 }
241
242 /// Assign the selected elements along the given dimension corresponding to the given indices
243 /// to the given value using sum reduction.
244 ///
245 /// # Arguments
246 ///
247 /// * `tensor` - The tensor to assign the values to.
248 /// * `dim` - The dimension to select from.
249 /// * `indices` - The indices of the elements to assign.
250 /// * `value` - The values to assign.
251 ///
252 /// # Returns
253 ///
254 /// The tensor with the assigned values.
255 fn bool_select_or(
256 tensor: BoolTensor<B>,
257 dim: usize,
258 indices: IntTensor<B>,
259 value: BoolTensor<B>,
260 ) -> BoolTensor<B> {
261 // Default implementation: convert to int, select_assign, then convert back to bool
262 let int_tensor = B::bool_into_int(tensor);
263 let int_values = B::bool_into_int(value);
264 let assigned = B::int_select_add(int_tensor, dim, indices, int_values);
265 // After select_assign with sum reduction, any non-zero value should be true
266 B::int_greater_elem(assigned, 0.into())
267 }
268
269 /// Repeats one dimension of the tensor a given number of times along that dimension.
270 ///
271 /// # Arguments
272 ///
273 /// * `tensor` - The tensor.
274 /// * `dim` - The dimension to repeat.
275 /// * `times` - The number of times to repeat the dimension.
276 ///
277 /// # Returns
278 ///
279 /// The tensor with the dimension repeated.
280 fn bool_repeat_dim(tensor: BoolTensor<B>, dim: usize, times: usize) -> BoolTensor<B> {
281 repeat_with_slice_assign::<B, Bool>(tensor, dim, times)
282 }
283
284 /// Concatenates the tensors along the given dimension.
285 ///
286 /// # Arguments
287 ///
288 /// * `tensors` - The tensors to concatenate.
289 /// * `dim` - The dimension to concatenate along.
290 ///
291 /// # Returns
292 ///
293 /// The tensor with the tensors concatenated along the given dimension.
294 ///
295 /// # Note
296 ///
297 /// Empty tensors (where the concatenation dimension has size 0) are filtered out at the
298 /// high-level tensor API and will not be passed to this method. Backend implementations do
299 /// not need to handle empty tensors.
300 fn bool_cat(tensors: Vec<BoolTensor<B>>, dim: usize) -> BoolTensor<B> {
301 cat_with_slice_assign::<B, Bool>(tensors, dim)
302 }
303
304 /// Equates the two tensors.
305 ///
306 /// # Arguments
307 ///
308 /// * `lhs` - The left hand side tensor.
309 /// * `rhs` - The right hand side tensor.
310 ///
311 /// # Returns
312 ///
313 /// The tensor with the result of the equate.
314 fn bool_equal(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;
315
316 /// Element-wise non-equality comparison.
317 ///
318 /// # Arguments
319 ///
320 /// * `lhs` - The left hand side tensor.
321 /// * `rhs` - The right hand side tensor.
322 ///
323 /// # Returns
324 ///
325 /// The tensor with the result of the comparison.
326 fn bool_not_equal(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
327 let equal_tensor = B::bool_equal(lhs, rhs);
328 B::bool_not(equal_tensor)
329 }
330
331 /// Element-wise equality comparison with a scalar.
332 ///
333 /// # Arguments
334 ///
335 /// * `lhs` - The left-hand side tensor.
336 /// * `rhs` - The right-hand side scalar.
337 ///
338 /// # Returns
339 ///
340 /// The boolean tensor with the result of the comparison.
341 fn bool_equal_elem(lhs: BoolTensor<B>, rhs: Scalar) -> BoolTensor<B>;
342
343 /// Element-wise non-equality comparison with a scalar.
344 ///
345 /// # Arguments
346 ///
347 /// * `lhs` - The left-hand side tensor.
348 /// * `rhs` - The right-hand side scalar.
349 ///
350 /// # Returns
351 ///
352 /// The boolean tensor with the result of the comparison.
353 fn bool_not_equal_elem(lhs: BoolTensor<B>, rhs: Scalar) -> BoolTensor<B> {
354 let equal_tensor = B::bool_equal_elem(lhs, rhs);
355 B::bool_not(equal_tensor)
356 }
357
358 /// Inverses boolean values.
359 ///
360 /// # Arguments
361 ///
362 /// * `tensor` - The tensor.
363 ///
364 /// # Returns
365 ///
366 /// The tensor with the result of the negation.
367 fn bool_not(tensor: BoolTensor<B>) -> BoolTensor<B>;
368
369 /// Executes the logical and (`&&`) operation on two boolean tensors.
370 ///
371 /// # Arguments
372 ///
373 /// * `lhs` - The left hand side tensor.
374 /// * `rhs` - The right hand side tensor.
375 ///
376 /// # Returns
377 ///
378 /// The tensor with the result of the logical and.
379 fn bool_and(tensor: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;
380
381 /// Executes the logical or (`||`) operation on two boolean tensors.
382 ///
383 /// # Arguments
384 ///
385 /// * `lhs` - The left hand side tensor.
386 /// * `rhs` - The right hand side tensor.
387 ///
388 /// # Returns
389 ///
390 /// The tensor with the result of the logical or.
391 fn bool_or(tensor: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;
392
393 /// Element-wise exclusive or.
394 ///
395 /// # Arguments
396 ///
397 /// * `lhs` - The left hand side tensor.
398 /// * `rhs` - The right hand side tensor.
399 ///
400 /// # Returns
401 ///
402 /// The tensor with the result of the comparison.
403 fn bool_xor(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
404 Self::bool_not_equal(lhs, rhs)
405 }
406
407 /// Transposes a bool tensor.
408 ///
409 /// # Arguments
410 ///
411 /// * `tensor` - The tensor to transpose.
412 ///
413 /// # Returns
414 ///
415 /// The transposed tensor.
416 fn bool_transpose(tensor: BoolTensor<B>) -> BoolTensor<B> {
417 let ndims = tensor.shape().num_dims();
418 Self::bool_swap_dims(tensor, ndims - 2, ndims - 1)
419 }
420
421 /// Swaps two dimensions of a bool tensor.
422 ///
423 /// # Arguments
424 ///
425 /// * `tensor` - The tensor to swap the dimensions of.
426 /// * `dim1` - The first dimension to swap.
427 /// * `dim2` - The second dimension to swap.
428 ///
429 /// # Returns
430 ///
431 /// The tensor with the dimensions swapped.
432 fn bool_swap_dims(tensor: BoolTensor<B>, dim1: usize, dim2: usize) -> BoolTensor<B>;
433
434 /// Permutes the dimensions of a tensor.
435 ///
436 /// # Arguments
437 ///
438 /// * `tensor` - The tensor to permute the dimensions of.
439 /// * `axes` - The new order of the dimensions.
440 /// # Returns
441 ///
442 /// The tensor with the dimensions permuted.
443 fn bool_permute(tensor: BoolTensor<B>, axes: &[usize]) -> BoolTensor<B>;
444
445 /// Reverse the order of elements in a tensor along the given axes.
446 ///
447 /// # Arguments
448 ///
449 /// * `tensor` - The tensor to reverse.
450 /// * `axes` - The axes to reverse.
451 ///
452 /// The tensor with the elements reversed.
453 fn bool_flip(tensor: BoolTensor<B>, axes: &[usize]) -> BoolTensor<B>;
454
455 /// Tests if any element in the boolean `tensor` evaluates to True.
456 ///
457 /// # Arguments
458 ///
459 /// * `tensor` - The tensor to test.
460 ///
461 /// # Returns
462 ///
463 /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
464 fn bool_any(tensor: BoolTensor<B>) -> BoolTensor<B> {
465 let sum = B::int_sum(B::bool_into_int(tensor));
466 B::int_greater_elem(sum, 0.into())
467 }
468
469 /// Tests if any element in the boolean `tensor` evaluates to True along a given dimension `dim`.
470 ///
471 /// # Arguments
472 ///
473 /// * `tensor` - The tensor to test.
474 /// * `dim` - The axis along which to test.
475 ///
476 /// # Returns
477 ///
478 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
479 /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input
480 /// evaluates to True, False otherwise.
481 fn bool_any_dim(tensor: BoolTensor<B>, dim: usize) -> BoolTensor<B> {
482 let sum = B::int_sum_dim(B::bool_into_int(tensor), dim);
483 B::int_greater_elem(sum, 0.into())
484 }
485
486 /// Tests if all elements in the boolean `tensor` evaluate to True.
487 ///
488 /// # Arguments
489 ///
490 /// * `tensor` - The tensor to test.
491 ///
492 /// # Returns
493 ///
494 /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
495 /// evaluate to True, False otherwise.
496 fn bool_all(tensor: BoolTensor<B>) -> BoolTensor<B> {
497 let num_elems = tensor.shape().num_elements() as i64;
498 let sum = B::int_sum(B::bool_into_int(tensor));
499 B::int_equal_elem(sum, num_elems.into())
500 }
501
502 /// Tests if all elements in the boolean `tensor` evaluate to True along a given dimension `dim`.
503 ///
504 /// # Arguments
505 ///
506 /// * `tensor` - The tensor to test.
507 /// * `dim` - The axis along which to test.
508 ///
509 /// # Returns
510 ///
511 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
512 /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
513 /// evaluates to True, False otherwise.
514 fn bool_all_dim(tensor: BoolTensor<B>, dim: usize) -> BoolTensor<B> {
515 let num_elems = tensor.shape()[dim] as i64;
516 let sum = B::int_sum_dim(B::bool_into_int(tensor), dim);
517 B::int_equal_elem(sum, num_elems.into())
518 }
519
520 /// Compute the indices of the elements that are non-zero, grouped by element.
521 ///
522 /// # Arguments
523 ///
524 /// * `tensor` - The input tensor.
525 ///
526 /// # Returns
527 ///
528 /// A 2D tensor containing the indices of all non-zero elements of the given tensor.
529 /// Each row contains the indices of a non-zero element.
530 fn bool_argwhere(tensor: BoolTensor<B>) -> impl Future<Output = IntTensor<B>> + 'static + Send {
531 async {
532 // Size of each output tensor is variable (= number of nonzero elements in the tensor).
533 // Reading the data to count the number of truth values might cause sync but is required.
534 let device = B::bool_device(&tensor);
535 let data = B::bool_into_data(tensor)
536 .await
537 .expect("Can read the data without error");
538 argwhere_data::<B>(data, &device)
539 }
540 }
541
542 /// Broadcasts the bool `tensor` to the given `shape`.
543 fn bool_expand(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B>;
544
545 /// Unfold windows along a dimension.
546 ///
547 /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;
548 /// where windows are advanced by `step` at each index.
549 ///
550 /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.
551 ///
552 /// # Arguments
553 ///
554 /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
555 /// * `dim` - the selected dim.
556 /// * `size` - the size of each unfolded window.
557 /// * `step` - the step between each window.
558 ///
559 /// # Returns
560 ///
561 /// A tensor view with shape ``[pre=..., windows, size, post=...]``.
562 fn bool_unfold(tensor: BoolTensor<B>, dim: usize, size: usize, step: usize) -> BoolTensor<B>;
563}