burn_backend/backend/ops/
bool_tensor.rs

1use super::{
2    argwhere::argwhere_data, cat::cat_with_slice_assign, repeat_dim::repeat_with_slice_assign,
3};
4use crate::ExecutionError;
5use crate::tensor::{Bool, BoolTensor, Device, FloatTensor, IntTensor};
6use crate::{Backend, TensorData, TensorMetadata, element::ElementConversion};
7use alloc::vec::Vec;
8use burn_std::{Shape, Slice};
9use core::future::Future;
10
11/// Bool Tensor API for basic operations, see
12#[cfg_attr(doc, doc = crate::doc_tensor!())]
13#[cfg_attr(not(doc), doc = "`Tensor`")]
14/// for documentation on each function.
15pub trait BoolTensorOps<B: Backend> {
16    /// Creates a new bool tensor.
17    ///
18    /// # Arguments
19    ///
20    /// * `shape` - The shape of the tensor.
21    /// * `device` - The device to create the tensor on.
22    ///
23    /// # Returns
24    ///
25    /// The boolean tensor with the given shape.
26    fn bool_empty(shape: Shape, device: &Device<B>) -> BoolTensor<B>;
27
28    /// Creates a new bool tensor filled false.
29    ///
30    /// # Arguments
31    ///
32    /// * `shape` - The shape of the tensor.
33    /// * `device` - The device to create the tensor on.
34    ///
35    /// # Returns
36    ///
37    /// The boolean tensor filled with false.
38    fn bool_zeros(shape: Shape, device: &Device<B>) -> BoolTensor<B>;
39
40    /// Creates a new bool tensor filled true.
41    ///
42    /// # Arguments
43    ///
44    /// * `shape` - The shape of the tensor.
45    /// * `device` - The device to create the tensor on.
46    ///
47    /// # Returns
48    ///
49    /// The boolean tensor filled with true.
50    fn bool_ones(shape: Shape, device: &Device<B>) -> BoolTensor<B>;
51
52    /// Converts the tensor to a data structure.
53    ///
54    /// # Arguments
55    ///
56    /// * `tensor` - The tensor.
57    ///
58    /// # Returns
59    ///
60    /// The data structure with the tensor's data.
61    fn bool_into_data(
62        tensor: BoolTensor<B>,
63    ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;
64
65    /// Creates a tensor from the data structure.
66    ///
67    /// # Arguments
68    ///
69    /// * `data` - The data structure.
70    /// * `device` - The device to create the tensor on.
71    ///
72    /// # Returns
73    ///
74    /// The tensor with the data.
75    fn bool_from_data(data: TensorData, device: &Device<B>) -> BoolTensor<B>;
76
77    /// Converts bool tensor to int tensor.
78    ///
79    /// # Arguments
80    ///
81    /// * `tensor` - The tensor.
82    ///
83    /// # Returns
84    ///
85    /// The int tensor with the same data as the bool tensor.
86    fn bool_into_int(tensor: BoolTensor<B>) -> IntTensor<B>;
87
88    /// Converts bool tensor to float tensor.
89    ///
90    /// # Arguments
91    ///
92    /// * `tensor` - The tensor.
93    ///
94    /// # Returns
95    ///
96    /// The float tensor with the same data as the bool tensor.
97    fn bool_into_float(tensor: BoolTensor<B>) -> FloatTensor<B>;
98
99    /// Gets the device of the tensor.
100    ///
101    /// # Arguments
102    ///
103    /// * `tensor` - The tensor.
104    ///
105    /// # Returns
106    ///
107    /// The device of the tensor.
108    fn bool_device(tensor: &BoolTensor<B>) -> Device<B>;
109
110    /// Moves the tensor to the device.
111    fn bool_to_device(tensor: BoolTensor<B>, device: &Device<B>) -> BoolTensor<B>;
112
113    /// Reshapes the tensor.
114    ///
115    /// # Arguments
116    ///
117    /// * `tensor` - The tensor.
118    /// * `shape` - The new shape.
119    ///
120    /// # Returns
121    ///
122    /// The tensor with the new shape.
123    fn bool_reshape(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B>;
124
125    /// Gets the values from the tensor for the given ranges.
126    ///
127    /// # Arguments
128    ///
129    /// * `tensor` - The tensor.
130    /// * `slices` - The slices specifying ranges and steps for each dimension.
131    ///
132    /// # Returns
133    ///
134    /// The tensor with the values for the given slices.
135    ///
136    /// # Note
137    ///
138    /// Empty slices (where start >= end) are handled at the high-level tensor API and will not
139    /// be passed to this method. Backend implementations do not need to handle empty slices.
140    fn bool_slice(tensor: BoolTensor<B>, slices: &[Slice]) -> BoolTensor<B>;
141
142    /// Sets the values in the tensor for the given ranges.
143    ///
144    /// # Arguments
145    ///
146    /// * `tensor` - The tensor.
147    /// * `ranges` - The ranges to set the values for.
148    /// * `value` - The values to set.
149    ///
150    /// # Returns
151    ///
152    /// The tensor with the values set for the given ranges.
153    ///
154    /// # Note
155    ///
156    /// Empty slice assignments (where any slice range produces 0 elements) are handled at the
157    /// high-level tensor API and will not be passed to this method. Backend implementations do
158    /// not need to handle empty slice assignments.
159    fn bool_slice_assign(
160        tensor: BoolTensor<B>,
161        slices: &[Slice],
162        value: BoolTensor<B>,
163    ) -> BoolTensor<B>;
164
165    /// Select tensor elements along the given dimension corresponding to the given indices.
166    ///
167    /// # Arguments
168    ///
169    /// * `tensor` - The tensor to select from.
170    /// * `dim` - The dimension to select from.
171    /// * `indices` - The indices of the elements to select.
172    ///
173    /// # Returns
174    ///
175    /// The tensor with the selected elements.
176    fn bool_select(tensor: BoolTensor<B>, dim: usize, indices: IntTensor<B>) -> BoolTensor<B> {
177        // Default implementation: convert to int, select, then convert back to bool
178        let int_tensor = B::bool_into_int(tensor);
179        let selected = B::int_select(int_tensor, dim, indices);
180        B::int_equal_elem(selected, 1_i32.elem())
181    }
182
183    /// Assign the selected elements along the given dimension corresponding to the given indices
184    /// to the given value using sum reduction.
185    ///
186    /// # Arguments
187    ///
188    /// * `tensor` - The tensor to assign the values to.
189    /// * `dim` - The dimension to select from.
190    /// * `indices` - The indices of the elements to assign.
191    /// * `value` - The values to assign.
192    ///
193    /// # Returns
194    ///
195    /// The tensor with the assigned values.
196    fn bool_select_add(
197        tensor: BoolTensor<B>,
198        dim: usize,
199        indices: IntTensor<B>,
200        value: BoolTensor<B>,
201    ) -> BoolTensor<B> {
202        // Default implementation: convert to int, select_assign, then convert back to bool
203        let int_tensor = B::bool_into_int(tensor);
204        let int_values = B::bool_into_int(value);
205        let assigned = B::int_select_add(int_tensor, dim, indices, int_values);
206        // After select_assign with sum reduction, any non-zero value should be true
207        B::int_greater_elem(assigned, 0_i32.elem())
208    }
209
210    /// Repeats one dimension of the tensor a given number of times along that dimension.
211    ///
212    /// # Arguments
213    ///
214    /// * `tensor` - The tensor.
215    /// * `dim` - The dimension to repeat.
216    /// * `times` - The number of times to repeat the dimension.
217    ///
218    /// # Returns
219    ///
220    /// The tensor with the dimension repeated.
221    fn bool_repeat_dim(tensor: BoolTensor<B>, dim: usize, times: usize) -> BoolTensor<B> {
222        repeat_with_slice_assign::<B, Bool>(tensor, dim, times)
223    }
224
225    /// Concatenates the tensors along the given dimension.
226    ///
227    /// # Arguments
228    ///
229    /// * `tensors` - The tensors to concatenate.
230    /// * `dim` - The dimension to concatenate along.
231    ///
232    /// # Returns
233    ///
234    /// The tensor with the tensors concatenated along the given dimension.
235    ///
236    /// # Note
237    ///
238    /// Empty tensors (where the concatenation dimension has size 0) are filtered out at the
239    /// high-level tensor API and will not be passed to this method. Backend implementations do
240    /// not need to handle empty tensors.
241    fn bool_cat(tensors: Vec<BoolTensor<B>>, dim: usize) -> BoolTensor<B> {
242        cat_with_slice_assign::<B, Bool>(tensors, dim)
243    }
244
245    /// Equates the two tensors.
246    ///
247    /// # Arguments
248    ///
249    /// * `lhs` - The left hand side tensor.
250    /// * `rhs` - The right hand side tensor.
251    ///
252    /// # Returns
253    ///
254    /// The tensor with the result of the equate.
255    fn bool_equal(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;
256
257    /// Element-wise non-equality comparison.
258    ///
259    /// # Arguments
260    ///
261    /// * `lhs` - The left hand side tensor.
262    /// * `rhs` - The right hand side tensor.
263    ///
264    /// # Returns
265    ///
266    /// The tensor with the result of the comparison.
267    fn bool_not_equal(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
268        let equal_tensor = B::bool_equal(lhs, rhs);
269        B::bool_not(equal_tensor)
270    }
271
272    /// Inverses boolean values.
273    ///
274    /// # Arguments
275    ///
276    /// * `tensor` - The tensor.
277    ///
278    /// # Returns
279    ///
280    /// The tensor with the result of the negation.
281    fn bool_not(tensor: BoolTensor<B>) -> BoolTensor<B>;
282
283    /// Executes the logical and (`&&`) operation on two boolean tensors.
284    ///
285    /// # Arguments
286    ///
287    /// * `lhs` - The left hand side tensor.
288    /// * `rhs` - The right hand side tensor.
289    ///
290    /// # Returns
291    ///
292    /// The tensor with the result of the logical and.
293    fn bool_and(tensor: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;
294
295    /// Executes the logical or (`||`) operation on two boolean tensors.
296    ///
297    /// # Arguments
298    ///
299    /// * `lhs` - The left hand side tensor.
300    /// * `rhs` - The right hand side tensor.
301    ///
302    /// # Returns
303    ///
304    /// The tensor with the result of the logical or.
305    fn bool_or(tensor: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;
306
307    /// Element-wise exclusive or.
308    ///
309    /// # Arguments
310    ///
311    /// * `lhs` - The left hand side tensor.
312    /// * `rhs` - The right hand side tensor.
313    ///
314    /// # Returns
315    ///
316    /// The tensor with the result of the comparison.
317    fn bool_xor(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
318        Self::bool_not_equal(lhs, rhs)
319    }
320
321    /// Transposes a bool tensor.
322    ///
323    /// # Arguments
324    ///
325    /// * `tensor` - The tensor to transpose.
326    ///
327    /// # Returns
328    ///
329    /// The transposed tensor.
330    fn bool_transpose(tensor: BoolTensor<B>) -> BoolTensor<B> {
331        let ndims = tensor.shape().num_dims();
332        Self::bool_swap_dims(tensor, ndims - 2, ndims - 1)
333    }
334
335    /// Swaps two dimensions of a bool tensor.
336    ///
337    /// # Arguments
338    ///
339    /// * `tensor` - The tensor to swap the dimensions of.
340    /// * `dim1` - The first dimension to swap.
341    /// * `dim2` - The second dimension to swap.
342    ///
343    /// # Returns
344    ///
345    /// The tensor with the dimensions swapped.
346    fn bool_swap_dims(tensor: BoolTensor<B>, dim1: usize, dim2: usize) -> BoolTensor<B>;
347
348    /// Permutes the dimensions of a tensor.
349    ///
350    /// # Arguments
351    ///
352    /// * `tensor` - The tensor to permute the dimensions of.
353    /// * `axes` - The new order of the dimensions.
354    /// # Returns
355    ///
356    /// The tensor with the dimensions permuted.
357    fn bool_permute(tensor: BoolTensor<B>, axes: &[usize]) -> BoolTensor<B>;
358
359    /// Reverse the order of elements in a tensor along the given axes.
360    ///
361    /// # Arguments
362    ///
363    /// * `tensor` - The tensor to reverse.
364    /// * `axes` - The axes to reverse.
365    ///
366    /// The tensor with the elements reversed.
367    fn bool_flip(tensor: BoolTensor<B>, axes: &[usize]) -> BoolTensor<B>;
368
369    /// Tests if any element in the boolean `tensor` evaluates to True.
370    ///
371    /// # Arguments
372    ///
373    /// * `tensor` - The tensor to test.
374    ///
375    /// # Returns
376    ///
377    /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
378    fn bool_any(tensor: BoolTensor<B>) -> BoolTensor<B> {
379        let sum = B::int_sum(B::bool_into_int(tensor));
380        B::int_greater_elem(sum, 0.elem())
381    }
382
383    /// Tests if any element in the boolean `tensor` evaluates to True along a given dimension `dim`.
384    ///
385    /// # Arguments
386    ///
387    /// * `tensor` - The tensor to test.
388    /// * `dim` - The axis along which to test.
389    ///
390    /// # Returns
391    ///
392    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
393    /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input
394    /// evaluates to True, False otherwise.
395    fn bool_any_dim(tensor: BoolTensor<B>, dim: usize) -> BoolTensor<B> {
396        let sum = B::int_sum_dim(B::bool_into_int(tensor), dim);
397        B::int_greater_elem(sum, 0.elem())
398    }
399
400    /// Tests if all elements in the boolean `tensor` evaluate to True.
401    ///
402    /// # Arguments
403    ///
404    /// * `tensor` - The tensor to test.
405    ///
406    /// # Returns
407    ///
408    /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
409    /// evaluate to True, False otherwise.
410    fn bool_all(tensor: BoolTensor<B>) -> BoolTensor<B> {
411        let num_elems = tensor.shape().num_elements();
412        let sum = B::int_sum(B::bool_into_int(tensor));
413        B::int_equal_elem(sum, (num_elems as i32).elem())
414    }
415
416    /// Tests if all elements in the boolean `tensor` evaluate to True along a given dimension `dim`.
417    ///
418    /// # Arguments
419    ///
420    /// * `tensor` - The tensor to test.
421    /// * `dim` - The axis along which to test.
422    ///
423    /// # Returns
424    ///
425    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
426    /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
427    /// evaluates to True, False otherwise.
428    fn bool_all_dim(tensor: BoolTensor<B>, dim: usize) -> BoolTensor<B> {
429        let num_elems = tensor.shape().dims[dim];
430        let sum = B::int_sum_dim(B::bool_into_int(tensor), dim);
431        B::int_equal_elem(sum, (num_elems as i32).elem())
432    }
433
434    /// Compute the indices of the elements that are non-zero, grouped by element.
435    ///
436    /// # Arguments
437    ///
438    /// * `tensor` - The input tensor.
439    ///
440    /// # Returns
441    ///
442    /// A 2D tensor containing the indices of all non-zero elements of the given tensor.
443    /// Each row contains the indices of a non-zero element.
444    fn bool_argwhere(tensor: BoolTensor<B>) -> impl Future<Output = IntTensor<B>> + 'static + Send {
445        async {
446            // Size of each output tensor is variable (= number of nonzero elements in the tensor).
447            // Reading the data to count the number of truth values might cause sync but is required.
448            let device = B::bool_device(&tensor);
449            let data = B::bool_into_data(tensor)
450                .await
451                .expect("Can read the data without error");
452            argwhere_data::<B>(data, &device)
453        }
454    }
455
456    /// Broadcasts the bool `tensor` to the given `shape`.
457    fn bool_expand(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B>;
458
459    /// Unfold windows along a dimension.
460    ///
461    /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;
462    /// where windows are advanced by `step` at each index.
463    ///
464    /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.
465    ///
466    /// # Arguments
467    ///
468    /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
469    /// * `dim` - the selected dim.
470    /// * `size` - the size of each unfolded window.
471    /// * `step` - the step between each window.
472    ///
473    /// # Returns
474    ///
475    /// A tensor view with shape ``[pre=..., windows, size, post=...]``.
476    fn bool_unfold(tensor: BoolTensor<B>, dim: usize, size: usize, step: usize) -> BoolTensor<B>;
477}