Skip to main content

burn_backend/backend/ops/
bool_tensor.rs

1use super::{
2    argwhere::argwhere_data, cat::cat_with_slice_assign, repeat_dim::repeat_with_slice_assign,
3};
4use crate::tensor::{Bool, BoolTensor, Device, FloatTensor, IntTensor};
5use crate::{Backend, TensorData, TensorMetadata};
6use crate::{ExecutionError, Scalar};
7use alloc::vec::Vec;
8use burn_std::{Shape, Slice};
9use core::future::Future;
10
11/// Bool Tensor API for basic operations, see
12#[cfg_attr(doc, doc = crate::doc_tensor!())]
13#[cfg_attr(not(doc), doc = "`Tensor`")]
14/// for documentation on each function.
15pub trait BoolTensorOps<B: Backend> {
16    /// Creates a new bool tensor.
17    ///
18    /// # Arguments
19    ///
20    /// * `shape` - The shape of the tensor.
21    /// * `device` - The device to create the tensor on.
22    ///
23    /// # Returns
24    ///
25    /// The boolean tensor with the given shape.
26    fn bool_empty(shape: Shape, device: &Device<B>) -> BoolTensor<B>;
27
28    /// Creates a new bool tensor filled false.
29    ///
30    /// # Arguments
31    ///
32    /// * `shape` - The shape of the tensor.
33    /// * `device` - The device to create the tensor on.
34    ///
35    /// # Returns
36    ///
37    /// The boolean tensor filled with false.
38    fn bool_zeros(shape: Shape, device: &Device<B>) -> BoolTensor<B>;
39
40    /// Creates a new bool tensor filled true.
41    ///
42    /// # Arguments
43    ///
44    /// * `shape` - The shape of the tensor.
45    /// * `device` - The device to create the tensor on.
46    ///
47    /// # Returns
48    ///
49    /// The boolean tensor filled with true.
50    fn bool_ones(shape: Shape, device: &Device<B>) -> BoolTensor<B>;
51
52    /// Converts the tensor to a data structure.
53    ///
54    /// # Arguments
55    ///
56    /// * `tensor` - The tensor.
57    ///
58    /// # Returns
59    ///
60    /// The data structure with the tensor's data.
61    fn bool_into_data(
62        tensor: BoolTensor<B>,
63    ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;
64
65    /// Creates a tensor from the data structure.
66    ///
67    /// # Arguments
68    ///
69    /// * `data` - The data structure.
70    /// * `device` - The device to create the tensor on.
71    ///
72    /// # Returns
73    ///
74    /// The tensor with the data.
75    fn bool_from_data(data: TensorData, device: &Device<B>) -> BoolTensor<B>;
76
77    /// Converts bool tensor to int tensor.
78    ///
79    /// # Arguments
80    ///
81    /// * `tensor` - The tensor.
82    ///
83    /// # Returns
84    ///
85    /// The int tensor with the same data as the bool tensor.
86    fn bool_into_int(tensor: BoolTensor<B>) -> IntTensor<B>;
87
88    /// Converts bool tensor to float tensor.
89    ///
90    /// # Arguments
91    ///
92    /// * `tensor` - The tensor.
93    ///
94    /// # Returns
95    ///
96    /// The float tensor with the same data as the bool tensor.
97    fn bool_into_float(tensor: BoolTensor<B>) -> FloatTensor<B>;
98
99    /// Gets the device of the tensor.
100    ///
101    /// # Arguments
102    ///
103    /// * `tensor` - The tensor.
104    ///
105    /// # Returns
106    ///
107    /// The device of the tensor.
108    fn bool_device(tensor: &BoolTensor<B>) -> Device<B>;
109
110    /// Moves the tensor to the device.
111    fn bool_to_device(tensor: BoolTensor<B>, device: &Device<B>) -> BoolTensor<B>;
112
113    /// Reshapes the tensor.
114    ///
115    /// # Arguments
116    ///
117    /// * `tensor` - The tensor.
118    /// * `shape` - The new shape.
119    ///
120    /// # Returns
121    ///
122    /// The tensor with the new shape.
123    fn bool_reshape(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B>;
124
125    /// Gets the values from the tensor for the given ranges.
126    ///
127    /// # Arguments
128    ///
129    /// * `tensor` - The tensor.
130    /// * `slices` - The slices specifying ranges and steps for each dimension.
131    ///
132    /// # Returns
133    ///
134    /// The tensor with the values for the given slices.
135    ///
136    /// # Note
137    ///
138    /// Empty slices (where start >= end) are handled at the high-level tensor API and will not
139    /// be passed to this method. Backend implementations do not need to handle empty slices.
140    fn bool_slice(tensor: BoolTensor<B>, slices: &[Slice]) -> BoolTensor<B>;
141
142    /// Sets the values in the tensor for the given ranges.
143    ///
144    /// # Arguments
145    ///
146    /// * `tensor` - The tensor.
147    /// * `ranges` - The ranges to set the values for.
148    /// * `value` - The values to set.
149    ///
150    /// # Returns
151    ///
152    /// The tensor with the values set for the given ranges.
153    ///
154    /// # Note
155    ///
156    /// Empty slice assignments (where any slice range produces 0 elements) are handled at the
157    /// high-level tensor API and will not be passed to this method. Backend implementations do
158    /// not need to handle empty slice assignments.
159    fn bool_slice_assign(
160        tensor: BoolTensor<B>,
161        slices: &[Slice],
162        value: BoolTensor<B>,
163    ) -> BoolTensor<B>;
164
165    /// Fills the tensor with values from the value tensor if the mask is true at the given
166    /// indices.
167    ///
168    /// # Arguments
169    ///
170    /// * `tensor` - The tensor.
171    /// * `mask` - The mask.
172    /// * `value` - The value tensor.
173    ///
174    /// # Returns
175    ///
176    /// The tensor with the values filled.
177    fn bool_mask_where(
178        tensor: BoolTensor<B>,
179        mask: BoolTensor<B>,
180        value: BoolTensor<B>,
181    ) -> BoolTensor<B>;
182
183    /// Fills the tensor with the given value if the mask is true at the given indices.
184    ///
185    /// # Arguments
186    ///
187    /// * `tensor` - The tensor.
188    /// * `mask` - The mask.
189    /// * `value` - The value.
190    ///
191    /// # Returns
192    ///
193    /// The tensor with the values filled.
194    fn bool_mask_fill(tensor: BoolTensor<B>, mask: BoolTensor<B>, value: Scalar) -> BoolTensor<B>;
195
196    /// Gather elements from the tensor at the given indices.
197    ///
198    /// # Arguments
199    ///
200    /// * `dim` - The dimension to gather from.
201    /// * `tensor` - The tensor.
202    /// * `indices` - The indices.
203    fn bool_gather(dim: usize, tensor: BoolTensor<B>, indices: IntTensor<B>) -> BoolTensor<B>;
204
205    /// Scatter a given value to the tensor at the given indices using boolean or reduction.
206    ///
207    /// # Arguments
208    ///
209    /// * `dim` - The dimension to scatter to.
210    /// * `tensor` - The tensor.
211    /// * `indices` - The indices.
212    /// * `value` - The value.
213    ///
214    /// # Returns
215    ///
216    /// The tensor with the values scattered.
217    fn bool_scatter_or(
218        dim: usize,
219        tensor: BoolTensor<B>,
220        indices: IntTensor<B>,
221        value: BoolTensor<B>,
222    ) -> BoolTensor<B>;
223
224    /// Select tensor elements along the given dimension corresponding to the given indices.
225    ///
226    /// # Arguments
227    ///
228    /// * `tensor` - The tensor to select from.
229    /// * `dim` - The dimension to select from.
230    /// * `indices` - The indices of the elements to select.
231    ///
232    /// # Returns
233    ///
234    /// The tensor with the selected elements.
235    fn bool_select(tensor: BoolTensor<B>, dim: usize, indices: IntTensor<B>) -> BoolTensor<B> {
236        // Default implementation: convert to int, select, then convert back to bool
237        let int_tensor = B::bool_into_int(tensor);
238        let selected = B::int_select(int_tensor, dim, indices);
239        B::int_equal_elem(selected, 1.into())
240    }
241
242    /// Assign the selected elements along the given dimension corresponding to the given indices
243    /// to the given value using sum reduction.
244    ///
245    /// # Arguments
246    ///
247    /// * `tensor` - The tensor to assign the values to.
248    /// * `dim` - The dimension to select from.
249    /// * `indices` - The indices of the elements to assign.
250    /// * `value` - The values to assign.
251    ///
252    /// # Returns
253    ///
254    /// The tensor with the assigned values.
255    fn bool_select_or(
256        tensor: BoolTensor<B>,
257        dim: usize,
258        indices: IntTensor<B>,
259        value: BoolTensor<B>,
260    ) -> BoolTensor<B> {
261        // Default implementation: convert to int, select_assign, then convert back to bool
262        let int_tensor = B::bool_into_int(tensor);
263        let int_values = B::bool_into_int(value);
264        let assigned = B::int_select_add(int_tensor, dim, indices, int_values);
265        // After select_assign with sum reduction, any non-zero value should be true
266        B::int_greater_elem(assigned, 0.into())
267    }
268
269    /// Repeats one dimension of the tensor a given number of times along that dimension.
270    ///
271    /// # Arguments
272    ///
273    /// * `tensor` - The tensor.
274    /// * `dim` - The dimension to repeat.
275    /// * `times` - The number of times to repeat the dimension.
276    ///
277    /// # Returns
278    ///
279    /// The tensor with the dimension repeated.
280    fn bool_repeat_dim(tensor: BoolTensor<B>, dim: usize, times: usize) -> BoolTensor<B> {
281        repeat_with_slice_assign::<B, Bool>(tensor, dim, times)
282    }
283
284    /// Concatenates the tensors along the given dimension.
285    ///
286    /// # Arguments
287    ///
288    /// * `tensors` - The tensors to concatenate.
289    /// * `dim` - The dimension to concatenate along.
290    ///
291    /// # Returns
292    ///
293    /// The tensor with the tensors concatenated along the given dimension.
294    ///
295    /// # Note
296    ///
297    /// Empty tensors (where the concatenation dimension has size 0) are filtered out at the
298    /// high-level tensor API and will not be passed to this method. Backend implementations do
299    /// not need to handle empty tensors.
300    fn bool_cat(tensors: Vec<BoolTensor<B>>, dim: usize) -> BoolTensor<B> {
301        cat_with_slice_assign::<B, Bool>(tensors, dim)
302    }
303
304    /// Equates the two tensors.
305    ///
306    /// # Arguments
307    ///
308    /// * `lhs` - The left hand side tensor.
309    /// * `rhs` - The right hand side tensor.
310    ///
311    /// # Returns
312    ///
313    /// The tensor with the result of the equate.
314    fn bool_equal(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;
315
316    /// Element-wise non-equality comparison.
317    ///
318    /// # Arguments
319    ///
320    /// * `lhs` - The left hand side tensor.
321    /// * `rhs` - The right hand side tensor.
322    ///
323    /// # Returns
324    ///
325    /// The tensor with the result of the comparison.
326    fn bool_not_equal(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
327        let equal_tensor = B::bool_equal(lhs, rhs);
328        B::bool_not(equal_tensor)
329    }
330
331    /// Element-wise equality comparison with a scalar.
332    ///
333    /// # Arguments
334    ///
335    /// * `lhs` - The left-hand side tensor.
336    /// * `rhs` - The right-hand side scalar.
337    ///
338    /// # Returns
339    ///
340    /// The boolean tensor with the result of the comparison.
341    fn bool_equal_elem(lhs: BoolTensor<B>, rhs: Scalar) -> BoolTensor<B>;
342
343    /// Element-wise non-equality comparison with a scalar.
344    ///
345    /// # Arguments
346    ///
347    /// * `lhs` - The left-hand side tensor.
348    /// * `rhs` - The right-hand side scalar.
349    ///
350    /// # Returns
351    ///
352    /// The boolean tensor with the result of the comparison.
353    fn bool_not_equal_elem(lhs: BoolTensor<B>, rhs: Scalar) -> BoolTensor<B> {
354        let equal_tensor = B::bool_equal_elem(lhs, rhs);
355        B::bool_not(equal_tensor)
356    }
357
358    /// Inverses boolean values.
359    ///
360    /// # Arguments
361    ///
362    /// * `tensor` - The tensor.
363    ///
364    /// # Returns
365    ///
366    /// The tensor with the result of the negation.
367    fn bool_not(tensor: BoolTensor<B>) -> BoolTensor<B>;
368
369    /// Executes the logical and (`&&`) operation on two boolean tensors.
370    ///
371    /// # Arguments
372    ///
373    /// * `lhs` - The left hand side tensor.
374    /// * `rhs` - The right hand side tensor.
375    ///
376    /// # Returns
377    ///
378    /// The tensor with the result of the logical and.
379    fn bool_and(tensor: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;
380
381    /// Executes the logical or (`||`) operation on two boolean tensors.
382    ///
383    /// # Arguments
384    ///
385    /// * `lhs` - The left hand side tensor.
386    /// * `rhs` - The right hand side tensor.
387    ///
388    /// # Returns
389    ///
390    /// The tensor with the result of the logical or.
391    fn bool_or(tensor: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;
392
393    /// Element-wise exclusive or.
394    ///
395    /// # Arguments
396    ///
397    /// * `lhs` - The left hand side tensor.
398    /// * `rhs` - The right hand side tensor.
399    ///
400    /// # Returns
401    ///
402    /// The tensor with the result of the comparison.
403    fn bool_xor(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
404        Self::bool_not_equal(lhs, rhs)
405    }
406
407    /// Transposes a bool tensor.
408    ///
409    /// # Arguments
410    ///
411    /// * `tensor` - The tensor to transpose.
412    ///
413    /// # Returns
414    ///
415    /// The transposed tensor.
416    fn bool_transpose(tensor: BoolTensor<B>) -> BoolTensor<B> {
417        let ndims = tensor.shape().num_dims();
418        Self::bool_swap_dims(tensor, ndims - 2, ndims - 1)
419    }
420
421    /// Swaps two dimensions of a bool tensor.
422    ///
423    /// # Arguments
424    ///
425    /// * `tensor` - The tensor to swap the dimensions of.
426    /// * `dim1` - The first dimension to swap.
427    /// * `dim2` - The second dimension to swap.
428    ///
429    /// # Returns
430    ///
431    /// The tensor with the dimensions swapped.
432    fn bool_swap_dims(tensor: BoolTensor<B>, dim1: usize, dim2: usize) -> BoolTensor<B>;
433
434    /// Permutes the dimensions of a tensor.
435    ///
436    /// # Arguments
437    ///
438    /// * `tensor` - The tensor to permute the dimensions of.
439    /// * `axes` - The new order of the dimensions.
440    /// # Returns
441    ///
442    /// The tensor with the dimensions permuted.
443    fn bool_permute(tensor: BoolTensor<B>, axes: &[usize]) -> BoolTensor<B>;
444
445    /// Reverse the order of elements in a tensor along the given axes.
446    ///
447    /// # Arguments
448    ///
449    /// * `tensor` - The tensor to reverse.
450    /// * `axes` - The axes to reverse.
451    ///
452    /// The tensor with the elements reversed.
453    fn bool_flip(tensor: BoolTensor<B>, axes: &[usize]) -> BoolTensor<B>;
454
455    /// Tests if any element in the boolean `tensor` evaluates to True.
456    ///
457    /// # Arguments
458    ///
459    /// * `tensor` - The tensor to test.
460    ///
461    /// # Returns
462    ///
463    /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
464    fn bool_any(tensor: BoolTensor<B>) -> BoolTensor<B> {
465        let sum = B::int_sum(B::bool_into_int(tensor));
466        B::int_greater_elem(sum, 0.into())
467    }
468
469    /// Tests if any element in the boolean `tensor` evaluates to True along a given dimension `dim`.
470    ///
471    /// # Arguments
472    ///
473    /// * `tensor` - The tensor to test.
474    /// * `dim` - The axis along which to test.
475    ///
476    /// # Returns
477    ///
478    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
479    /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input
480    /// evaluates to True, False otherwise.
481    fn bool_any_dim(tensor: BoolTensor<B>, dim: usize) -> BoolTensor<B> {
482        let sum = B::int_sum_dim(B::bool_into_int(tensor), dim);
483        B::int_greater_elem(sum, 0.into())
484    }
485
486    /// Tests if all elements in the boolean `tensor` evaluate to True.
487    ///
488    /// # Arguments
489    ///
490    /// * `tensor` - The tensor to test.
491    ///
492    /// # Returns
493    ///
494    /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
495    /// evaluate to True, False otherwise.
496    fn bool_all(tensor: BoolTensor<B>) -> BoolTensor<B> {
497        let num_elems = tensor.shape().num_elements() as i64;
498        let sum = B::int_sum(B::bool_into_int(tensor));
499        B::int_equal_elem(sum, num_elems.into())
500    }
501
502    /// Tests if all elements in the boolean `tensor` evaluate to True along a given dimension `dim`.
503    ///
504    /// # Arguments
505    ///
506    /// * `tensor` - The tensor to test.
507    /// * `dim` - The axis along which to test.
508    ///
509    /// # Returns
510    ///
511    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
512    /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
513    /// evaluates to True, False otherwise.
514    fn bool_all_dim(tensor: BoolTensor<B>, dim: usize) -> BoolTensor<B> {
515        let num_elems = tensor.shape()[dim] as i64;
516        let sum = B::int_sum_dim(B::bool_into_int(tensor), dim);
517        B::int_equal_elem(sum, num_elems.into())
518    }
519
520    /// Compute the indices of the elements that are non-zero, grouped by element.
521    ///
522    /// # Arguments
523    ///
524    /// * `tensor` - The input tensor.
525    ///
526    /// # Returns
527    ///
528    /// A 2D tensor containing the indices of all non-zero elements of the given tensor.
529    /// Each row contains the indices of a non-zero element.
530    fn bool_argwhere(tensor: BoolTensor<B>) -> impl Future<Output = IntTensor<B>> + 'static + Send {
531        async {
532            // Size of each output tensor is variable (= number of nonzero elements in the tensor).
533            // Reading the data to count the number of truth values might cause sync but is required.
534            let device = B::bool_device(&tensor);
535            let data = B::bool_into_data(tensor)
536                .await
537                .expect("Can read the data without error");
538            argwhere_data::<B>(data, &device)
539        }
540    }
541
542    /// Broadcasts the bool `tensor` to the given `shape`.
543    fn bool_expand(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B>;
544
545    /// Unfold windows along a dimension.
546    ///
547    /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;
548    /// where windows are advanced by `step` at each index.
549    ///
550    /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.
551    ///
552    /// # Arguments
553    ///
554    /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
555    /// * `dim` - the selected dim.
556    /// * `size` - the size of each unfolded window.
557    /// * `step` - the step between each window.
558    ///
559    /// # Returns
560    ///
561    /// A tensor view with shape ``[pre=..., windows, size, post=...]``.
562    fn bool_unfold(tensor: BoolTensor<B>, dim: usize, size: usize, step: usize) -> BoolTensor<B>;
563}