burn_tensor/tensor/ops/
bool_tensor.rs

1use super::{
2    BoolTensor, Device, FloatTensor, IntTensor, cat::cat_with_slice_assign,
3    repeat_dim::repeat_with_slice_assign,
4};
5use crate::{
6    Bool, ElementConversion, TensorData, TensorMetadata, argwhere_data, backend::Backend,
7    tensor::Shape,
8};
9use alloc::vec::Vec;
10use core::future::Future;
11
12/// Bool Tensor API for basic operations, see [tensor](crate::Tensor)
13/// for documentation on each function.
14pub trait BoolTensorOps<B: Backend> {
15    /// Creates a new bool tensor.
16    ///
17    /// # Arguments
18    ///
19    /// * `shape` - The shape of the tensor.
20    /// * `device` - The device to create the tensor on.
21    ///
22    /// # Returns
23    ///
24    /// The boolean tensor with the given shape.
25    fn bool_empty(shape: Shape, device: &Device<B>) -> BoolTensor<B>;
26
27    /// Creates a new bool tensor filled false.
28    ///
29    /// # Arguments
30    ///
31    /// * `shape` - The shape of the tensor.
32    /// * `device` - The device to create the tensor on.
33    ///
34    /// # Returns
35    ///
36    /// The boolean tensor filled with false.
37    fn bool_zeros(shape: Shape, device: &Device<B>) -> BoolTensor<B>;
38
39    /// Creates a new bool tensor filled true.
40    ///
41    /// # Arguments
42    ///
43    /// * `shape` - The shape of the tensor.
44    /// * `device` - The device to create the tensor on.
45    ///
46    /// # Returns
47    ///
48    /// The boolean tensor filled with true.
49    fn bool_ones(shape: Shape, device: &Device<B>) -> BoolTensor<B>;
50
51    /// Converts the tensor to a data structure.
52    ///
53    /// # Arguments
54    ///
55    /// * `tensor` - The tensor.
56    ///
57    /// # Returns
58    ///
59    /// The data structure with the tensor's data.
60    fn bool_into_data(tensor: BoolTensor<B>) -> impl Future<Output = TensorData> + Send;
61
62    /// Creates a tensor from the data structure.
63    ///
64    /// # Arguments
65    ///
66    /// * `data` - The data structure.
67    /// * `device` - The device to create the tensor on.
68    ///
69    /// # Returns
70    ///
71    /// The tensor with the data.
72    fn bool_from_data(data: TensorData, device: &Device<B>) -> BoolTensor<B>;
73
74    /// Converts bool tensor to int tensor.
75    ///
76    /// # Arguments
77    ///
78    /// * `tensor` - The tensor.
79    ///
80    /// # Returns
81    ///
82    /// The int tensor with the same data as the bool tensor.
83    fn bool_into_int(tensor: BoolTensor<B>) -> IntTensor<B>;
84
85    /// Converts bool tensor to float tensor.
86    ///
87    /// # Arguments
88    ///
89    /// * `tensor` - The tensor.
90    ///
91    /// # Returns
92    ///
93    /// The float tensor with the same data as the bool tensor.
94    fn bool_into_float(tensor: BoolTensor<B>) -> FloatTensor<B>;
95
96    /// Gets the device of the tensor.
97    ///
98    /// # Arguments
99    ///
100    /// * `tensor` - The tensor.
101    ///
102    /// # Returns
103    ///
104    /// The device of the tensor.
105    fn bool_device(tensor: &BoolTensor<B>) -> Device<B>;
106
107    /// Moves the tensor to the device.
108    fn bool_to_device(tensor: BoolTensor<B>, device: &Device<B>) -> BoolTensor<B>;
109
110    /// Reshapes the tensor.
111    ///
112    /// # Arguments
113    ///
114    /// * `tensor` - The tensor.
115    /// * `shape` - The new shape.
116    ///
117    /// # Returns
118    ///
119    /// The tensor with the new shape.
120    fn bool_reshape(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B>;
121
122    /// Gets the values from the tensor for the given ranges.
123    ///
124    /// # Arguments
125    ///
126    /// * `tensor` - The tensor.
127    /// * `slices` - The slices specifying ranges and steps for each dimension.
128    ///
129    /// # Returns
130    ///
131    /// The tensor with the values for the given slices.
132    fn bool_slice(tensor: BoolTensor<B>, slices: &[crate::Slice]) -> BoolTensor<B>;
133
134    /// Sets the values in the tensor for the given ranges.
135    ///
136    /// # Arguments
137    ///
138    /// * `tensor` - The tensor.
139    /// * `ranges` - The ranges to set the values for.
140    /// * `value` - The values to set.
141    ///
142    /// # Returns
143    ///
144    /// The tensor with the values set for the given ranges.
145    fn bool_slice_assign(
146        tensor: BoolTensor<B>,
147        slices: &[crate::Slice],
148        value: BoolTensor<B>,
149    ) -> BoolTensor<B>;
150
151    /// Select tensor elements along the given dimension corresponding to the given indices.
152    ///
153    /// # Arguments
154    ///
155    /// * `tensor` - The tensor to select from.
156    /// * `dim` - The dimension to select from.
157    /// * `indices` - The indices of the elements to select.
158    ///
159    /// # Returns
160    ///
161    /// The tensor with the selected elements.
162    fn bool_select(tensor: BoolTensor<B>, dim: usize, indices: IntTensor<B>) -> BoolTensor<B> {
163        // Default implementation: convert to int, select, then convert back to bool
164        let int_tensor = B::bool_into_int(tensor);
165        let selected = B::int_select(int_tensor, dim, indices);
166        B::int_equal_elem(selected, 1_i32.elem())
167    }
168
169    /// Assign the selected elements along the given dimension corresponding to the given indices
170    /// to the given value.
171    ///
172    /// # Arguments
173    ///
174    /// * `tensor` - The tensor to assign the values to.
175    /// * `dim` - The dimension to select from.
176    /// * `indices` - The indices of the elements to assign.
177    /// * `value` - The values to assign.
178    ///
179    /// # Returns
180    ///
181    /// The tensor with the assigned values.
182    fn bool_select_assign(
183        tensor: BoolTensor<B>,
184        dim: usize,
185        indices: IntTensor<B>,
186        value: BoolTensor<B>,
187    ) -> BoolTensor<B> {
188        // Default implementation: convert to int, select_assign, then convert back to bool
189        let int_tensor = B::bool_into_int(tensor);
190        let int_values = B::bool_into_int(value);
191        let assigned = B::int_select_assign(int_tensor, dim, indices, int_values);
192        // After select_assign with sum reduction, any non-zero value should be true
193        B::int_greater_elem(assigned, 0_i32.elem())
194    }
195
196    /// Repeats one dimension of the tensor a given number of times along that dimension.
197    ///
198    /// # Arguments
199    ///
200    /// * `tensor` - The tensor.
201    /// * `dim` - The dimension to repeat.
202    /// * `times` - The number of times to repeat the dimension.
203    ///
204    /// # Returns
205    ///
206    /// The tensor with the dimension repeated.
207    fn bool_repeat_dim(tensor: BoolTensor<B>, dim: usize, times: usize) -> BoolTensor<B> {
208        repeat_with_slice_assign::<B, Bool>(tensor, dim, times)
209    }
210
211    /// Concatenates the tensors along the given dimension.
212    ///
213    /// # Arguments
214    ///
215    /// * `tensors` - The tensors to concatenate.
216    /// * `dim` - The dimension to concatenate along.
217    ///
218    /// # Returns
219    ///
220    /// The tensor with the tensors concatenated along the given dimension.
221    fn bool_cat(tensors: Vec<BoolTensor<B>>, dim: usize) -> BoolTensor<B> {
222        cat_with_slice_assign::<B, Bool>(tensors, dim)
223    }
224
225    /// Equates the two tensors.
226    ///
227    /// # Arguments
228    ///
229    /// * `lhs` - The left hand side tensor.
230    /// * `rhs` - The right hand side tensor.
231    ///
232    /// # Returns
233    ///
234    /// The tensor with the result of the equate.
235    fn bool_equal(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;
236
237    /// Element-wise non-equality comparison.
238    ///
239    /// # Arguments
240    ///
241    /// * `lhs` - The left hand side tensor.
242    /// * `rhs` - The right hand side tensor.
243    ///
244    /// # Returns
245    ///
246    /// The tensor with the result of the comparison.
247    fn bool_not_equal(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
248        let equal_tensor = B::bool_equal(lhs, rhs);
249        B::bool_not(equal_tensor)
250    }
251
252    /// Inverses boolean values.
253    ///
254    /// # Arguments
255    ///
256    /// * `tensor` - The tensor.
257    ///
258    /// # Returns
259    ///
260    /// The tensor with the result of the negation.
261    fn bool_not(tensor: BoolTensor<B>) -> BoolTensor<B>;
262
263    /// Executes the logical and (`&&`) operation on two boolean tensors.
264    ///
265    /// # Arguments
266    ///
267    /// * `lhs` - The left hand side tensor.
268    /// * `rhs` - The right hand side tensor.
269    ///
270    /// # Returns
271    ///
272    /// The tensor with the result of the logical and.
273    fn bool_and(tensor: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;
274
275    /// Executes the logical or (`||`) operation on two boolean tensors.
276    ///
277    /// # Arguments
278    ///
279    /// * `lhs` - The left hand side tensor.
280    /// * `rhs` - The right hand side tensor.
281    ///
282    /// # Returns
283    ///
284    /// The tensor with the result of the logical or.
285    fn bool_or(tensor: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;
286
287    /// Element-wise exclusive or.
288    ///
289    /// # Arguments
290    ///
291    /// * `lhs` - The left hand side tensor.
292    /// * `rhs` - The right hand side tensor.
293    ///
294    /// # Returns
295    ///
296    /// The tensor with the result of the comparison.
297    fn bool_xor(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
298        Self::bool_not_equal(lhs, rhs)
299    }
300
301    /// Transposes a bool tensor.
302    ///
303    /// # Arguments
304    ///
305    /// * `tensor` - The tensor to transpose.
306    ///
307    /// # Returns
308    ///
309    /// The transposed tensor.
310    fn bool_transpose(tensor: BoolTensor<B>) -> BoolTensor<B> {
311        let ndims = tensor.shape().num_dims();
312        Self::bool_swap_dims(tensor, ndims - 2, ndims - 1)
313    }
314
315    /// Swaps two dimensions of a bool tensor.
316    ///
317    /// # Arguments
318    ///
319    /// * `tensor` - The tensor to swap the dimensions of.
320    /// * `dim1` - The first dimension to swap.
321    /// * `dim2` - The second dimension to swap.
322    ///
323    /// # Returns
324    ///
325    /// The tensor with the dimensions swapped.
326    fn bool_swap_dims(tensor: BoolTensor<B>, dim1: usize, dim2: usize) -> BoolTensor<B>;
327
328    /// Permutes the dimensions of a tensor.
329    ///
330    /// # Arguments
331    ///
332    /// * `tensor` - The tensor to permute the dimensions of.
333    /// * `axes` - The new order of the dimensions.
334    /// # Returns
335    ///
336    /// The tensor with the dimensions permuted.
337    fn bool_permute(tensor: BoolTensor<B>, axes: &[usize]) -> BoolTensor<B>;
338
339    /// Reverse the order of elements in a tensor along the given axes.
340    ///
341    /// # Arguments
342    ///
343    /// * `tensor` - The tensor to reverse.
344    /// * `axes` - The axes to reverse.
345    ///
346    /// The tensor with the elements reversed.
347    fn bool_flip(tensor: BoolTensor<B>, axes: &[usize]) -> BoolTensor<B>;
348
349    /// Tests if any element in the boolean `tensor` evaluates to True.
350    ///
351    /// # Arguments
352    ///
353    /// * `tensor` - The tensor to test.
354    ///
355    /// # Returns
356    ///
357    /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
358    fn bool_any(tensor: BoolTensor<B>) -> BoolTensor<B> {
359        let sum = B::int_sum(B::bool_into_int(tensor));
360        B::int_greater_elem(sum, 0.elem())
361    }
362
363    /// Tests if any element in the boolean `tensor` evaluates to True along a given dimension `dim`.
364    ///
365    /// # Arguments
366    ///
367    /// * `tensor` - The tensor to test.
368    /// * `dim` - The axis along which to test.
369    ///
370    /// # Returns
371    ///
372    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
373    /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input
374    /// evaluates to True, False otherwise.
375    fn bool_any_dim(tensor: BoolTensor<B>, dim: usize) -> BoolTensor<B> {
376        let sum = B::int_sum_dim(B::bool_into_int(tensor), dim);
377        B::int_greater_elem(sum, 0.elem())
378    }
379
380    /// Tests if all elements in the boolean `tensor` evaluate to True.
381    ///
382    /// # Arguments
383    ///
384    /// * `tensor` - The tensor to test.
385    ///
386    /// # Returns
387    ///
388    /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
389    /// evaluate to True, False otherwise.
390    fn bool_all(tensor: BoolTensor<B>) -> BoolTensor<B> {
391        let num_elems = tensor.shape().num_elements();
392        let sum = B::int_sum(B::bool_into_int(tensor));
393        B::int_equal_elem(sum, (num_elems as i32).elem())
394    }
395
396    /// Tests if all elements in the boolean `tensor` evaluate to True along a given dimension `dim`.
397    ///
398    /// # Arguments
399    ///
400    /// * `tensor` - The tensor to test.
401    /// * `dim` - The axis along which to test.
402    ///
403    /// # Returns
404    ///
405    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
406    /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
407    /// evaluates to True, False otherwise.
408    fn bool_all_dim(tensor: BoolTensor<B>, dim: usize) -> BoolTensor<B> {
409        let num_elems = tensor.shape().dims[dim];
410        let sum = B::int_sum_dim(B::bool_into_int(tensor), dim);
411        B::int_equal_elem(sum, (num_elems as i32).elem())
412    }
413
414    /// Compute the indices of the elements that are non-zero, grouped by element.
415    ///
416    /// # Arguments
417    ///
418    /// * `tensor` - The input tensor.
419    ///
420    /// # Returns
421    ///
422    /// A 2D tensor containing the indices of all non-zero elements of the given tensor.
423    /// Each row contains the indices of a non-zero element.
424    fn bool_argwhere(tensor: BoolTensor<B>) -> impl Future<Output = IntTensor<B>> + 'static + Send {
425        async {
426            // Size of each output tensor is variable (= number of nonzero elements in the tensor).
427            // Reading the data to count the number of truth values might cause sync but is required.
428            let device = B::bool_device(&tensor);
429            let data = B::bool_into_data(tensor).await;
430            argwhere_data::<B>(data, &device)
431        }
432    }
433
434    /// Broadcasts the bool `tensor` to the given `shape`.
435    fn bool_expand(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B>;
436
437    /// Unfold windows along a dimension.
438    ///
439    /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;
440    /// where windows are advanced by `step` at each index.
441    ///
442    /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.
443    ///
444    /// # Arguments
445    ///
446    /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
447    /// * `dim` - the selected dim.
448    /// * `size` - the size of each unfolded window.
449    /// * `step` - the step between each window.
450    ///
451    /// # Returns
452    ///
453    /// A tensor view with shape ``[pre=..., windows, size, post=...]``.
454    fn bool_unfold(tensor: BoolTensor<B>, dim: usize, size: usize, step: usize) -> BoolTensor<B>;
455}