burn_backend/backend/ops/
bool_tensor.rs

1use super::{
2    argwhere::argwhere_data, cat::cat_with_slice_assign, repeat_dim::repeat_with_slice_assign,
3};
4use crate::ExecutionError;
5use crate::tensor::{Bool, BoolElem, BoolTensor, Device, FloatTensor, IntTensor};
6use crate::{Backend, TensorData, TensorMetadata, element::ElementConversion};
7use alloc::vec::Vec;
8use burn_std::{Shape, Slice};
9use core::future::Future;
10
11/// Bool Tensor API for basic operations, see
12#[cfg_attr(doc, doc = crate::doc_tensor!())]
13#[cfg_attr(not(doc), doc = "`Tensor`")]
14/// for documentation on each function.
15pub trait BoolTensorOps<B: Backend> {
16    /// Creates a new bool tensor.
17    ///
18    /// # Arguments
19    ///
20    /// * `shape` - The shape of the tensor.
21    /// * `device` - The device to create the tensor on.
22    ///
23    /// # Returns
24    ///
25    /// The boolean tensor with the given shape.
26    fn bool_empty(shape: Shape, device: &Device<B>) -> BoolTensor<B>;
27
28    /// Creates a new bool tensor filled false.
29    ///
30    /// # Arguments
31    ///
32    /// * `shape` - The shape of the tensor.
33    /// * `device` - The device to create the tensor on.
34    ///
35    /// # Returns
36    ///
37    /// The boolean tensor filled with false.
38    fn bool_zeros(shape: Shape, device: &Device<B>) -> BoolTensor<B>;
39
40    /// Creates a new bool tensor filled true.
41    ///
42    /// # Arguments
43    ///
44    /// * `shape` - The shape of the tensor.
45    /// * `device` - The device to create the tensor on.
46    ///
47    /// # Returns
48    ///
49    /// The boolean tensor filled with true.
50    fn bool_ones(shape: Shape, device: &Device<B>) -> BoolTensor<B>;
51
52    /// Converts the tensor to a data structure.
53    ///
54    /// # Arguments
55    ///
56    /// * `tensor` - The tensor.
57    ///
58    /// # Returns
59    ///
60    /// The data structure with the tensor's data.
61    fn bool_into_data(
62        tensor: BoolTensor<B>,
63    ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;
64
65    /// Creates a tensor from the data structure.
66    ///
67    /// # Arguments
68    ///
69    /// * `data` - The data structure.
70    /// * `device` - The device to create the tensor on.
71    ///
72    /// # Returns
73    ///
74    /// The tensor with the data.
75    fn bool_from_data(data: TensorData, device: &Device<B>) -> BoolTensor<B>;
76
77    /// Converts bool tensor to int tensor.
78    ///
79    /// # Arguments
80    ///
81    /// * `tensor` - The tensor.
82    ///
83    /// # Returns
84    ///
85    /// The int tensor with the same data as the bool tensor.
86    fn bool_into_int(tensor: BoolTensor<B>) -> IntTensor<B>;
87
88    /// Converts bool tensor to float tensor.
89    ///
90    /// # Arguments
91    ///
92    /// * `tensor` - The tensor.
93    ///
94    /// # Returns
95    ///
96    /// The float tensor with the same data as the bool tensor.
97    fn bool_into_float(tensor: BoolTensor<B>) -> FloatTensor<B>;
98
99    /// Gets the device of the tensor.
100    ///
101    /// # Arguments
102    ///
103    /// * `tensor` - The tensor.
104    ///
105    /// # Returns
106    ///
107    /// The device of the tensor.
108    fn bool_device(tensor: &BoolTensor<B>) -> Device<B>;
109
110    /// Moves the tensor to the device.
111    fn bool_to_device(tensor: BoolTensor<B>, device: &Device<B>) -> BoolTensor<B>;
112
113    /// Reshapes the tensor.
114    ///
115    /// # Arguments
116    ///
117    /// * `tensor` - The tensor.
118    /// * `shape` - The new shape.
119    ///
120    /// # Returns
121    ///
122    /// The tensor with the new shape.
123    fn bool_reshape(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B>;
124
125    /// Gets the values from the tensor for the given ranges.
126    ///
127    /// # Arguments
128    ///
129    /// * `tensor` - The tensor.
130    /// * `slices` - The slices specifying ranges and steps for each dimension.
131    ///
132    /// # Returns
133    ///
134    /// The tensor with the values for the given slices.
135    ///
136    /// # Note
137    ///
138    /// Empty slices (where start >= end) are handled at the high-level tensor API and will not
139    /// be passed to this method. Backend implementations do not need to handle empty slices.
140    fn bool_slice(tensor: BoolTensor<B>, slices: &[Slice]) -> BoolTensor<B>;
141
142    /// Sets the values in the tensor for the given ranges.
143    ///
144    /// # Arguments
145    ///
146    /// * `tensor` - The tensor.
147    /// * `ranges` - The ranges to set the values for.
148    /// * `value` - The values to set.
149    ///
150    /// # Returns
151    ///
152    /// The tensor with the values set for the given ranges.
153    ///
154    /// # Note
155    ///
156    /// Empty slice assignments (where any slice range produces 0 elements) are handled at the
157    /// high-level tensor API and will not be passed to this method. Backend implementations do
158    /// not need to handle empty slice assignments.
159    fn bool_slice_assign(
160        tensor: BoolTensor<B>,
161        slices: &[Slice],
162        value: BoolTensor<B>,
163    ) -> BoolTensor<B>;
164
165    /// Fills the tensor with values from the value tensor if the mask is true at the given
166    /// indices.
167    ///
168    /// # Arguments
169    ///
170    /// * `tensor` - The tensor.
171    /// * `mask` - The mask.
172    /// * `value` - The value tensor.
173    ///
174    /// # Returns
175    ///
176    /// The tensor with the values filled.
177    fn bool_mask_where(
178        tensor: BoolTensor<B>,
179        mask: BoolTensor<B>,
180        value: BoolTensor<B>,
181    ) -> BoolTensor<B>;
182
183    /// Fills the tensor with the given value if the mask is true at the given indices.
184    ///
185    /// # Arguments
186    ///
187    /// * `tensor` - The tensor.
188    /// * `mask` - The mask.
189    /// * `value` - The value.
190    ///
191    /// # Returns
192    ///
193    /// The tensor with the values filled.
194    fn bool_mask_fill(
195        tensor: BoolTensor<B>,
196        mask: BoolTensor<B>,
197        value: BoolElem<B>,
198    ) -> BoolTensor<B>;
199
200    /// Gather elements from the tensor at the given indices.
201    ///
202    /// # Arguments
203    ///
204    /// * `dim` - The dimension to gather from.
205    /// * `tensor` - The tensor.
206    /// * `indices` - The indices.
207    fn bool_gather(dim: usize, tensor: BoolTensor<B>, indices: IntTensor<B>) -> BoolTensor<B>;
208
209    /// Scatter a given value to the tensor at the given indices using boolean or reduction.
210    ///
211    /// # Arguments
212    ///
213    /// * `dim` - The dimension to scatter to.
214    /// * `tensor` - The tensor.
215    /// * `indices` - The indices.
216    /// * `value` - The value.
217    ///
218    /// # Returns
219    ///
220    /// The tensor with the values scattered.
221    fn bool_scatter_or(
222        dim: usize,
223        tensor: BoolTensor<B>,
224        indices: IntTensor<B>,
225        value: BoolTensor<B>,
226    ) -> BoolTensor<B>;
227
228    /// Select tensor elements along the given dimension corresponding to the given indices.
229    ///
230    /// # Arguments
231    ///
232    /// * `tensor` - The tensor to select from.
233    /// * `dim` - The dimension to select from.
234    /// * `indices` - The indices of the elements to select.
235    ///
236    /// # Returns
237    ///
238    /// The tensor with the selected elements.
239    fn bool_select(tensor: BoolTensor<B>, dim: usize, indices: IntTensor<B>) -> BoolTensor<B> {
240        // Default implementation: convert to int, select, then convert back to bool
241        let int_tensor = B::bool_into_int(tensor);
242        let selected = B::int_select(int_tensor, dim, indices);
243        B::int_equal_elem(selected, 1_i32.elem())
244    }
245
246    /// Assign the selected elements along the given dimension corresponding to the given indices
247    /// to the given value using sum reduction.
248    ///
249    /// # Arguments
250    ///
251    /// * `tensor` - The tensor to assign the values to.
252    /// * `dim` - The dimension to select from.
253    /// * `indices` - The indices of the elements to assign.
254    /// * `value` - The values to assign.
255    ///
256    /// # Returns
257    ///
258    /// The tensor with the assigned values.
259    fn bool_select_or(
260        tensor: BoolTensor<B>,
261        dim: usize,
262        indices: IntTensor<B>,
263        value: BoolTensor<B>,
264    ) -> BoolTensor<B> {
265        // Default implementation: convert to int, select_assign, then convert back to bool
266        let int_tensor = B::bool_into_int(tensor);
267        let int_values = B::bool_into_int(value);
268        let assigned = B::int_select_add(int_tensor, dim, indices, int_values);
269        // After select_assign with sum reduction, any non-zero value should be true
270        B::int_greater_elem(assigned, 0_i32.elem())
271    }
272
273    /// Repeats one dimension of the tensor a given number of times along that dimension.
274    ///
275    /// # Arguments
276    ///
277    /// * `tensor` - The tensor.
278    /// * `dim` - The dimension to repeat.
279    /// * `times` - The number of times to repeat the dimension.
280    ///
281    /// # Returns
282    ///
283    /// The tensor with the dimension repeated.
284    fn bool_repeat_dim(tensor: BoolTensor<B>, dim: usize, times: usize) -> BoolTensor<B> {
285        repeat_with_slice_assign::<B, Bool>(tensor, dim, times)
286    }
287
288    /// Concatenates the tensors along the given dimension.
289    ///
290    /// # Arguments
291    ///
292    /// * `tensors` - The tensors to concatenate.
293    /// * `dim` - The dimension to concatenate along.
294    ///
295    /// # Returns
296    ///
297    /// The tensor with the tensors concatenated along the given dimension.
298    ///
299    /// # Note
300    ///
301    /// Empty tensors (where the concatenation dimension has size 0) are filtered out at the
302    /// high-level tensor API and will not be passed to this method. Backend implementations do
303    /// not need to handle empty tensors.
304    fn bool_cat(tensors: Vec<BoolTensor<B>>, dim: usize) -> BoolTensor<B> {
305        cat_with_slice_assign::<B, Bool>(tensors, dim)
306    }
307
308    /// Equates the two tensors.
309    ///
310    /// # Arguments
311    ///
312    /// * `lhs` - The left hand side tensor.
313    /// * `rhs` - The right hand side tensor.
314    ///
315    /// # Returns
316    ///
317    /// The tensor with the result of the equate.
318    fn bool_equal(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;
319
320    /// Element-wise non-equality comparison.
321    ///
322    /// # Arguments
323    ///
324    /// * `lhs` - The left hand side tensor.
325    /// * `rhs` - The right hand side tensor.
326    ///
327    /// # Returns
328    ///
329    /// The tensor with the result of the comparison.
330    fn bool_not_equal(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
331        let equal_tensor = B::bool_equal(lhs, rhs);
332        B::bool_not(equal_tensor)
333    }
334
335    /// Element-wise equality comparison with a scalar.
336    ///
337    /// # Arguments
338    ///
339    /// * `lhs` - The left-hand side tensor.
340    /// * `rhs` - The right-hand side scalar.
341    ///
342    /// # Returns
343    ///
344    /// The boolean tensor with the result of the comparison.
345    fn bool_equal_elem(lhs: BoolTensor<B>, rhs: BoolElem<B>) -> BoolTensor<B>;
346
347    /// Element-wise non-equality comparison with a scalar.
348    ///
349    /// # Arguments
350    ///
351    /// * `lhs` - The left-hand side tensor.
352    /// * `rhs` - The right-hand side scalar.
353    ///
354    /// # Returns
355    ///
356    /// The boolean tensor with the result of the comparison.
357    fn bool_not_equal_elem(lhs: BoolTensor<B>, rhs: BoolElem<B>) -> BoolTensor<B> {
358        let equal_tensor = B::bool_equal_elem(lhs, rhs);
359        B::bool_not(equal_tensor)
360    }
361
362    /// Inverses boolean values.
363    ///
364    /// # Arguments
365    ///
366    /// * `tensor` - The tensor.
367    ///
368    /// # Returns
369    ///
370    /// The tensor with the result of the negation.
371    fn bool_not(tensor: BoolTensor<B>) -> BoolTensor<B>;
372
373    /// Executes the logical and (`&&`) operation on two boolean tensors.
374    ///
375    /// # Arguments
376    ///
377    /// * `lhs` - The left hand side tensor.
378    /// * `rhs` - The right hand side tensor.
379    ///
380    /// # Returns
381    ///
382    /// The tensor with the result of the logical and.
383    fn bool_and(tensor: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;
384
385    /// Executes the logical or (`||`) operation on two boolean tensors.
386    ///
387    /// # Arguments
388    ///
389    /// * `lhs` - The left hand side tensor.
390    /// * `rhs` - The right hand side tensor.
391    ///
392    /// # Returns
393    ///
394    /// The tensor with the result of the logical or.
395    fn bool_or(tensor: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;
396
397    /// Element-wise exclusive or.
398    ///
399    /// # Arguments
400    ///
401    /// * `lhs` - The left hand side tensor.
402    /// * `rhs` - The right hand side tensor.
403    ///
404    /// # Returns
405    ///
406    /// The tensor with the result of the comparison.
407    fn bool_xor(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
408        Self::bool_not_equal(lhs, rhs)
409    }
410
411    /// Transposes a bool tensor.
412    ///
413    /// # Arguments
414    ///
415    /// * `tensor` - The tensor to transpose.
416    ///
417    /// # Returns
418    ///
419    /// The transposed tensor.
420    fn bool_transpose(tensor: BoolTensor<B>) -> BoolTensor<B> {
421        let ndims = tensor.shape().num_dims();
422        Self::bool_swap_dims(tensor, ndims - 2, ndims - 1)
423    }
424
425    /// Swaps two dimensions of a bool tensor.
426    ///
427    /// # Arguments
428    ///
429    /// * `tensor` - The tensor to swap the dimensions of.
430    /// * `dim1` - The first dimension to swap.
431    /// * `dim2` - The second dimension to swap.
432    ///
433    /// # Returns
434    ///
435    /// The tensor with the dimensions swapped.
436    fn bool_swap_dims(tensor: BoolTensor<B>, dim1: usize, dim2: usize) -> BoolTensor<B>;
437
438    /// Permutes the dimensions of a tensor.
439    ///
440    /// # Arguments
441    ///
442    /// * `tensor` - The tensor to permute the dimensions of.
443    /// * `axes` - The new order of the dimensions.
444    /// # Returns
445    ///
446    /// The tensor with the dimensions permuted.
447    fn bool_permute(tensor: BoolTensor<B>, axes: &[usize]) -> BoolTensor<B>;
448
449    /// Reverse the order of elements in a tensor along the given axes.
450    ///
451    /// # Arguments
452    ///
453    /// * `tensor` - The tensor to reverse.
454    /// * `axes` - The axes to reverse.
455    ///
456    /// The tensor with the elements reversed.
457    fn bool_flip(tensor: BoolTensor<B>, axes: &[usize]) -> BoolTensor<B>;
458
459    /// Tests if any element in the boolean `tensor` evaluates to True.
460    ///
461    /// # Arguments
462    ///
463    /// * `tensor` - The tensor to test.
464    ///
465    /// # Returns
466    ///
467    /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
468    fn bool_any(tensor: BoolTensor<B>) -> BoolTensor<B> {
469        let sum = B::int_sum(B::bool_into_int(tensor));
470        B::int_greater_elem(sum, 0.elem())
471    }
472
473    /// Tests if any element in the boolean `tensor` evaluates to True along a given dimension `dim`.
474    ///
475    /// # Arguments
476    ///
477    /// * `tensor` - The tensor to test.
478    /// * `dim` - The axis along which to test.
479    ///
480    /// # Returns
481    ///
482    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
483    /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input
484    /// evaluates to True, False otherwise.
485    fn bool_any_dim(tensor: BoolTensor<B>, dim: usize) -> BoolTensor<B> {
486        let sum = B::int_sum_dim(B::bool_into_int(tensor), dim);
487        B::int_greater_elem(sum, 0.elem())
488    }
489
490    /// Tests if all elements in the boolean `tensor` evaluate to True.
491    ///
492    /// # Arguments
493    ///
494    /// * `tensor` - The tensor to test.
495    ///
496    /// # Returns
497    ///
498    /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
499    /// evaluate to True, False otherwise.
500    fn bool_all(tensor: BoolTensor<B>) -> BoolTensor<B> {
501        let num_elems = tensor.shape().num_elements();
502        let sum = B::int_sum(B::bool_into_int(tensor));
503        B::int_equal_elem(sum, (num_elems as i32).elem())
504    }
505
506    /// Tests if all elements in the boolean `tensor` evaluate to True along a given dimension `dim`.
507    ///
508    /// # Arguments
509    ///
510    /// * `tensor` - The tensor to test.
511    /// * `dim` - The axis along which to test.
512    ///
513    /// # Returns
514    ///
515    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
516    /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
517    /// evaluates to True, False otherwise.
518    fn bool_all_dim(tensor: BoolTensor<B>, dim: usize) -> BoolTensor<B> {
519        let num_elems = tensor.shape().dims[dim];
520        let sum = B::int_sum_dim(B::bool_into_int(tensor), dim);
521        B::int_equal_elem(sum, (num_elems as i32).elem())
522    }
523
524    /// Compute the indices of the elements that are non-zero, grouped by element.
525    ///
526    /// # Arguments
527    ///
528    /// * `tensor` - The input tensor.
529    ///
530    /// # Returns
531    ///
532    /// A 2D tensor containing the indices of all non-zero elements of the given tensor.
533    /// Each row contains the indices of a non-zero element.
534    fn bool_argwhere(tensor: BoolTensor<B>) -> impl Future<Output = IntTensor<B>> + 'static + Send {
535        async {
536            // Size of each output tensor is variable (= number of nonzero elements in the tensor).
537            // Reading the data to count the number of truth values might cause sync but is required.
538            let device = B::bool_device(&tensor);
539            let data = B::bool_into_data(tensor)
540                .await
541                .expect("Can read the data without error");
542            argwhere_data::<B>(data, &device)
543        }
544    }
545
546    /// Broadcasts the bool `tensor` to the given `shape`.
547    fn bool_expand(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B>;
548
549    /// Unfold windows along a dimension.
550    ///
551    /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;
552    /// where windows are advanced by `step` at each index.
553    ///
554    /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.
555    ///
556    /// # Arguments
557    ///
558    /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
559    /// * `dim` - the selected dim.
560    /// * `size` - the size of each unfolded window.
561    /// * `step` - the step between each window.
562    ///
563    /// # Returns
564    ///
565    /// A tensor view with shape ``[pre=..., windows, size, post=...]``.
566    fn bool_unfold(tensor: BoolTensor<B>, dim: usize, size: usize, step: usize) -> BoolTensor<B>;
567}