Skip to main content

burn_backend/backend/ops/
bool_tensor.rs

1use super::{
2    argwhere::argwhere_data, cat::cat_with_slice_assign, repeat_dim::repeat_with_slice_assign,
3};
4use crate::tensor::{Bool, BoolTensor, Device, FloatTensor, IntTensor};
5use crate::{Backend, TensorData, TensorMetadata, get_device_settings};
6use crate::{ExecutionError, Scalar};
7use alloc::vec::Vec;
8use burn_std::{BoolDType, FloatDType, IntDType, Shape, Slice};
9use core::future::Future;
10
11/// Bool Tensor API for basic operations, see
12#[cfg_attr(doc, doc = crate::doc_tensor!())]
13#[cfg_attr(not(doc), doc = "`Tensor`")]
14/// for documentation on each function.
15pub trait BoolTensorOps<B: Backend> {
16    /// Creates a new bool tensor.
17    ///
18    /// # Arguments
19    ///
20    /// * `shape` - The shape of the tensor.
21    /// * `device` - The device to create the tensor on.
22    /// * `dtype` - The target data type.
23    ///
24    /// # Returns
25    ///
26    /// The boolean tensor with the given shape.
27    fn bool_empty(shape: Shape, device: &Device<B>, dtype: BoolDType) -> BoolTensor<B>;
28
29    /// Creates a new bool tensor filled false.
30    ///
31    /// # Arguments
32    ///
33    /// * `shape` - The shape of the tensor.
34    /// * `device` - The device to create the tensor on.
35    /// * `dtype` - The target data type.
36    ///
37    /// # Returns
38    ///
39    /// The boolean tensor filled with false.
40    fn bool_zeros(shape: Shape, device: &Device<B>, dtype: BoolDType) -> BoolTensor<B>;
41
42    /// Creates a new bool tensor filled true.
43    ///
44    /// # Arguments
45    ///
46    /// * `shape` - The shape of the tensor.
47    /// * `device` - The device to create the tensor on.
48    /// * `dtype` - The target data type.
49    ///
50    /// # Returns
51    ///
52    /// The boolean tensor filled with true.
53    fn bool_ones(shape: Shape, device: &Device<B>, dtype: BoolDType) -> BoolTensor<B>;
54
55    /// Converts the tensor to a data structure.
56    ///
57    /// # Arguments
58    ///
59    /// * `tensor` - The tensor.
60    ///
61    /// # Returns
62    ///
63    /// The data structure with the tensor's data.
64    fn bool_into_data(
65        tensor: BoolTensor<B>,
66    ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;
67
68    /// Creates a tensor from the data structure.
69    ///
70    /// # Arguments
71    ///
72    /// * `data` - The data structure.
73    /// * `device` - The device to create the tensor on.
74    ///
75    /// # Returns
76    ///
77    /// The tensor with the data.
78    fn bool_from_data(data: TensorData, device: &Device<B>) -> BoolTensor<B>;
79
80    /// Converts bool tensor to int tensor.
81    ///
82    /// # Arguments
83    ///
84    /// * `tensor` - The tensor.
85    /// * `out_dtype` - The output tensor dtype.
86    ///
87    /// # Returns
88    ///
89    /// The int tensor with the same data as the bool tensor.
90    fn bool_into_int(tensor: BoolTensor<B>, out_dtype: IntDType) -> IntTensor<B>;
91
92    /// Converts bool tensor to float tensor.
93    ///
94    /// # Arguments
95    ///
96    /// * `tensor` - The tensor.
97    /// * `out_dtype` - The output tensor dtype.
98    ///
99    /// # Returns
100    ///
101    /// The float tensor with the same data as the bool tensor.
102    fn bool_into_float(tensor: BoolTensor<B>, out_dtype: FloatDType) -> FloatTensor<B>;
103
104    /// Gets the device of the tensor.
105    ///
106    /// # Arguments
107    ///
108    /// * `tensor` - The tensor.
109    ///
110    /// # Returns
111    ///
112    /// The device of the tensor.
113    fn bool_device(tensor: &BoolTensor<B>) -> Device<B>;
114
115    /// Moves the tensor to the device.
116    fn bool_to_device(tensor: BoolTensor<B>, device: &Device<B>) -> BoolTensor<B>;
117
118    /// Reshapes the tensor.
119    ///
120    /// # Arguments
121    ///
122    /// * `tensor` - The tensor.
123    /// * `shape` - The new shape.
124    ///
125    /// # Returns
126    ///
127    /// The tensor with the new shape.
128    fn bool_reshape(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B>;
129
130    /// Gets the values from the tensor for the given ranges.
131    ///
132    /// # Arguments
133    ///
134    /// * `tensor` - The tensor.
135    /// * `slices` - The slices specifying ranges and steps for each dimension.
136    ///
137    /// # Returns
138    ///
139    /// The tensor with the values for the given slices.
140    ///
141    /// # Note
142    ///
143    /// Empty slices (where start >= end) are handled at the high-level tensor API and will not
144    /// be passed to this method. Backend implementations do not need to handle empty slices.
145    fn bool_slice(tensor: BoolTensor<B>, slices: &[Slice]) -> BoolTensor<B>;
146
147    /// Sets the values in the tensor for the given ranges.
148    ///
149    /// # Arguments
150    ///
151    /// * `tensor` - The tensor.
152    /// * `ranges` - The ranges to set the values for.
153    /// * `value` - The values to set.
154    ///
155    /// # Returns
156    ///
157    /// The tensor with the values set for the given ranges.
158    ///
159    /// # Note
160    ///
161    /// Empty slice assignments (where any slice range produces 0 elements) are handled at the
162    /// high-level tensor API and will not be passed to this method. Backend implementations do
163    /// not need to handle empty slice assignments.
164    fn bool_slice_assign(
165        tensor: BoolTensor<B>,
166        slices: &[Slice],
167        value: BoolTensor<B>,
168    ) -> BoolTensor<B>;
169
170    /// Fills the tensor with values from the value tensor if the mask is true at the given
171    /// indices.
172    ///
173    /// # Arguments
174    ///
175    /// * `tensor` - The tensor.
176    /// * `mask` - The mask.
177    /// * `value` - The value tensor.
178    ///
179    /// # Returns
180    ///
181    /// The tensor with the values filled.
182    fn bool_mask_where(
183        tensor: BoolTensor<B>,
184        mask: BoolTensor<B>,
185        value: BoolTensor<B>,
186    ) -> BoolTensor<B>;
187
188    /// Fills the tensor with the given value if the mask is true at the given indices.
189    ///
190    /// # Arguments
191    ///
192    /// * `tensor` - The tensor.
193    /// * `mask` - The mask.
194    /// * `value` - The value.
195    ///
196    /// # Returns
197    ///
198    /// The tensor with the values filled.
199    fn bool_mask_fill(tensor: BoolTensor<B>, mask: BoolTensor<B>, value: Scalar) -> BoolTensor<B>;
200
201    /// Gather elements from the tensor at the given indices.
202    ///
203    /// # Arguments
204    ///
205    /// * `dim` - The dimension to gather from.
206    /// * `tensor` - The tensor.
207    /// * `indices` - The indices.
208    fn bool_gather(dim: usize, tensor: BoolTensor<B>, indices: IntTensor<B>) -> BoolTensor<B>;
209
210    /// Scatter a given value to the tensor at the given indices using boolean or reduction.
211    ///
212    /// # Arguments
213    ///
214    /// * `dim` - The dimension to scatter to.
215    /// * `tensor` - The tensor.
216    /// * `indices` - The indices.
217    /// * `value` - The value.
218    ///
219    /// # Returns
220    ///
221    /// The tensor with the values scattered.
222    fn bool_scatter_or(
223        dim: usize,
224        tensor: BoolTensor<B>,
225        indices: IntTensor<B>,
226        value: BoolTensor<B>,
227    ) -> BoolTensor<B>;
228
229    /// Select tensor elements along the given dimension corresponding to the given indices.
230    ///
231    /// # Arguments
232    ///
233    /// * `tensor` - The tensor to select from.
234    /// * `dim` - The dimension to select from.
235    /// * `indices` - The indices of the elements to select.
236    ///
237    /// # Returns
238    ///
239    /// The tensor with the selected elements.
240    fn bool_select(tensor: BoolTensor<B>, dim: usize, indices: IntTensor<B>) -> BoolTensor<B>;
241
242    /// Assign the selected elements along the given dimension corresponding to the given indices
243    /// to the given value using sum reduction.
244    ///
245    /// # Arguments
246    ///
247    /// * `tensor` - The tensor to assign the values to.
248    /// * `dim` - The dimension to select from.
249    /// * `indices` - The indices of the elements to assign.
250    /// * `value` - The values to assign.
251    ///
252    /// # Returns
253    ///
254    /// The tensor with the assigned values.
255    fn bool_select_or(
256        tensor: BoolTensor<B>,
257        dim: usize,
258        indices: IntTensor<B>,
259        value: BoolTensor<B>,
260    ) -> BoolTensor<B>;
261
262    /// Repeats one dimension of the tensor a given number of times along that dimension.
263    ///
264    /// # Arguments
265    ///
266    /// * `tensor` - The tensor.
267    /// * `dim` - The dimension to repeat.
268    /// * `times` - The number of times to repeat the dimension.
269    ///
270    /// # Returns
271    ///
272    /// The tensor with the dimension repeated.
273    fn bool_repeat_dim(tensor: BoolTensor<B>, dim: usize, times: usize) -> BoolTensor<B> {
274        repeat_with_slice_assign::<B, Bool>(tensor, dim, times)
275    }
276
277    /// Concatenates the tensors along the given dimension.
278    ///
279    /// # Arguments
280    ///
281    /// * `tensors` - The tensors to concatenate.
282    /// * `dim` - The dimension to concatenate along.
283    ///
284    /// # Returns
285    ///
286    /// The tensor with the tensors concatenated along the given dimension.
287    ///
288    /// # Note
289    ///
290    /// Empty tensors (where the concatenation dimension has size 0) are filtered out at the
291    /// high-level tensor API and will not be passed to this method. Backend implementations do
292    /// not need to handle empty tensors.
293    fn bool_cat(tensors: Vec<BoolTensor<B>>, dim: usize) -> BoolTensor<B> {
294        cat_with_slice_assign::<B, Bool>(tensors, dim)
295    }
296
297    /// Equates the two tensors.
298    ///
299    /// # Arguments
300    ///
301    /// * `lhs` - The left hand side tensor.
302    /// * `rhs` - The right hand side tensor.
303    ///
304    /// # Returns
305    ///
306    /// The tensor with the result of the equate.
307    fn bool_equal(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;
308
309    /// Element-wise non-equality comparison.
310    ///
311    /// # Arguments
312    ///
313    /// * `lhs` - The left hand side tensor.
314    /// * `rhs` - The right hand side tensor.
315    ///
316    /// # Returns
317    ///
318    /// The tensor with the result of the comparison.
319    fn bool_not_equal(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
320        let equal_tensor = B::bool_equal(lhs, rhs);
321        B::bool_not(equal_tensor)
322    }
323
324    /// Element-wise equality comparison with a scalar.
325    ///
326    /// # Arguments
327    ///
328    /// * `lhs` - The left-hand side tensor.
329    /// * `rhs` - The right-hand side scalar.
330    ///
331    /// # Returns
332    ///
333    /// The boolean tensor with the result of the comparison.
334    fn bool_equal_elem(lhs: BoolTensor<B>, rhs: Scalar) -> BoolTensor<B>;
335
336    /// Element-wise non-equality comparison with a scalar.
337    ///
338    /// # Arguments
339    ///
340    /// * `lhs` - The left-hand side tensor.
341    /// * `rhs` - The right-hand side scalar.
342    ///
343    /// # Returns
344    ///
345    /// The boolean tensor with the result of the comparison.
346    fn bool_not_equal_elem(lhs: BoolTensor<B>, rhs: Scalar) -> BoolTensor<B> {
347        let equal_tensor = B::bool_equal_elem(lhs, rhs);
348        B::bool_not(equal_tensor)
349    }
350
351    /// Inverses boolean values.
352    ///
353    /// # Arguments
354    ///
355    /// * `tensor` - The tensor.
356    ///
357    /// # Returns
358    ///
359    /// The tensor with the result of the negation.
360    fn bool_not(tensor: BoolTensor<B>) -> BoolTensor<B>;
361
362    /// Executes the logical and (`&&`) operation on two boolean tensors.
363    ///
364    /// # Arguments
365    ///
366    /// * `lhs` - The left hand side tensor.
367    /// * `rhs` - The right hand side tensor.
368    ///
369    /// # Returns
370    ///
371    /// The tensor with the result of the logical and.
372    fn bool_and(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;
373
374    /// Executes the logical or (`||`) operation on two boolean tensors.
375    ///
376    /// # Arguments
377    ///
378    /// * `lhs` - The left hand side tensor.
379    /// * `rhs` - The right hand side tensor.
380    ///
381    /// # Returns
382    ///
383    /// The tensor with the result of the logical or.
384    fn bool_or(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;
385
386    /// Element-wise exclusive or.
387    ///
388    /// # Arguments
389    ///
390    /// * `lhs` - The left hand side tensor.
391    /// * `rhs` - The right hand side tensor.
392    ///
393    /// # Returns
394    ///
395    /// The tensor with the result of the comparison.
396    fn bool_xor(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
397        Self::bool_not_equal(lhs, rhs)
398    }
399
400    /// Transposes a bool tensor.
401    ///
402    /// # Arguments
403    ///
404    /// * `tensor` - The tensor to transpose.
405    ///
406    /// # Returns
407    ///
408    /// The transposed tensor.
409    fn bool_transpose(tensor: BoolTensor<B>) -> BoolTensor<B> {
410        let ndims = tensor.shape().num_dims();
411        Self::bool_swap_dims(tensor, ndims - 2, ndims - 1)
412    }
413
414    /// Swaps two dimensions of a bool tensor.
415    ///
416    /// # Arguments
417    ///
418    /// * `tensor` - The tensor to swap the dimensions of.
419    /// * `dim1` - The first dimension to swap.
420    /// * `dim2` - The second dimension to swap.
421    ///
422    /// # Returns
423    ///
424    /// The tensor with the dimensions swapped.
425    fn bool_swap_dims(tensor: BoolTensor<B>, dim1: usize, dim2: usize) -> BoolTensor<B>;
426
427    /// Permutes the dimensions of a tensor.
428    ///
429    /// # Arguments
430    ///
431    /// * `tensor` - The tensor to permute the dimensions of.
432    /// * `axes` - The new order of the dimensions.
433    /// # Returns
434    ///
435    /// The tensor with the dimensions permuted.
436    fn bool_permute(tensor: BoolTensor<B>, axes: &[usize]) -> BoolTensor<B>;
437
438    /// Reverse the order of elements in a tensor along the given axes.
439    ///
440    /// # Arguments
441    ///
442    /// * `tensor` - The tensor to reverse.
443    /// * `axes` - The axes to reverse.
444    ///
445    /// The tensor with the elements reversed.
446    fn bool_flip(tensor: BoolTensor<B>, axes: &[usize]) -> BoolTensor<B>;
447
448    /// Tests if any element in the boolean `tensor` evaluates to True.
449    ///
450    /// # Arguments
451    ///
452    /// * `tensor` - The tensor to test.
453    ///
454    /// # Returns
455    ///
456    /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
457    fn bool_any(tensor: BoolTensor<B>) -> BoolTensor<B> {
458        let dtype = tensor.dtype();
459        let int_dtype = get_device_settings::<B>(&B::bool_device(&tensor)).int_dtype;
460        let sum = B::int_sum(B::bool_into_int(tensor, int_dtype));
461        B::int_greater_elem(sum, 0.into(), dtype.into())
462    }
463
464    /// Tests if any element in the boolean `tensor` evaluates to True along a given dimension `dim`.
465    ///
466    /// # Arguments
467    ///
468    /// * `tensor` - The tensor to test.
469    /// * `dim` - The axis along which to test.
470    ///
471    /// # Returns
472    ///
473    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
474    /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input
475    /// evaluates to True, False otherwise.
476    fn bool_any_dim(tensor: BoolTensor<B>, dim: usize) -> BoolTensor<B> {
477        let dtype = tensor.dtype();
478        let int_dtype = get_device_settings::<B>(&B::bool_device(&tensor)).int_dtype;
479        let sum = B::int_sum_dim(B::bool_into_int(tensor, int_dtype), dim);
480        B::int_greater_elem(sum, 0.into(), dtype.into())
481    }
482
483    /// Tests if all elements in the boolean `tensor` evaluate to True.
484    ///
485    /// # Arguments
486    ///
487    /// * `tensor` - The tensor to test.
488    ///
489    /// # Returns
490    ///
491    /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
492    /// evaluate to True, False otherwise.
493    fn bool_all(tensor: BoolTensor<B>) -> BoolTensor<B> {
494        let dtype = tensor.dtype();
495        let int_dtype = get_device_settings::<B>(&B::bool_device(&tensor)).int_dtype;
496        let num_elems = tensor.shape().num_elements() as i64;
497        let sum = B::int_sum(B::bool_into_int(tensor, int_dtype));
498        B::int_equal_elem(sum, num_elems.into(), dtype.into())
499    }
500
501    /// Tests if all elements in the boolean `tensor` evaluate to True along a given dimension `dim`.
502    ///
503    /// # Arguments
504    ///
505    /// * `tensor` - The tensor to test.
506    /// * `dim` - The axis along which to test.
507    ///
508    /// # Returns
509    ///
510    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
511    /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
512    /// evaluates to True, False otherwise.
513    fn bool_all_dim(tensor: BoolTensor<B>, dim: usize) -> BoolTensor<B> {
514        let dtype = tensor.dtype();
515        let int_dtype = get_device_settings::<B>(&B::bool_device(&tensor)).int_dtype;
516        let num_elems = tensor.shape()[dim] as i64;
517        let sum = B::int_sum_dim(B::bool_into_int(tensor, int_dtype), dim);
518        B::int_equal_elem(sum, num_elems.into(), dtype.into())
519    }
520
521    /// Compute the indices of the elements that are non-zero, grouped by element.
522    ///
523    /// # Arguments
524    ///
525    /// * `tensor` - The input tensor.
526    /// * `out_dtype` - The output tensor dtype.
527    ///
528    /// # Returns
529    ///
530    /// A 2D tensor containing the indices of all non-zero elements of the given tensor.
531    /// Each row contains the indices of a non-zero element.
532    fn bool_argwhere(
533        tensor: BoolTensor<B>,
534        out_dtype: IntDType,
535    ) -> impl Future<Output = IntTensor<B>> + 'static + Send {
536        async move {
537            // Size of each output tensor is variable (= number of nonzero elements in the tensor).
538            // Reading the data to count the number of truth values might cause sync but is required.
539            let device = B::bool_device(&tensor);
540            let data = B::bool_into_data(tensor)
541                .await
542                .expect("Can read the data without error");
543            argwhere_data::<B>(data, &device, out_dtype)
544        }
545    }
546
547    /// Broadcasts the bool `tensor` to the given `shape`.
548    fn bool_expand(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B>;
549
550    /// Unfold windows along a dimension.
551    ///
552    /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;
553    /// where windows are advanced by `step` at each index.
554    ///
555    /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.
556    ///
557    /// # Arguments
558    ///
559    /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
560    /// * `dim` - the selected dim.
561    /// * `size` - the size of each unfolded window.
562    /// * `step` - the step between each window.
563    ///
564    /// # Returns
565    ///
566    /// A tensor view with shape ``[pre=..., windows, size, post=...]``.
567    fn bool_unfold(tensor: BoolTensor<B>, dim: usize, size: usize, step: usize) -> BoolTensor<B>;
568}