burn_tensor/tensor/ops/bool_tensor.rs

1use super::{
2    BoolTensor, Device, FloatTensor, IntTensor, cat::cat_with_slice_assign,
3    repeat_dim::repeat_with_slice_assign,
4};
5use crate::{
6    Bool, ElementConversion, TensorData, TensorMetadata, argwhere_data, backend::Backend, chunk,
7    narrow, split, split_with_sizes, tensor::Shape,
8};
9use alloc::{vec, vec::Vec};
10use core::{future::Future, ops::Range};
11
12/// Bool Tensor API for basic operations, see [tensor](crate::Tensor)
13/// for documentation on each function.
14pub trait BoolTensorOps<B: Backend> {
15    /// Creates a new bool tensor.
16    ///
17    /// # Arguments
18    ///
19    /// * `shape` - The shape of the tensor.
20    /// * `device` - The device to create the tensor on.
21    ///
22    /// # Returns
23    ///
24    /// The boolean tensor with the given shape.
25    fn bool_empty(shape: Shape, device: &Device<B>) -> BoolTensor<B>;
26
27    /// Converts the tensor to a data structure.
28    ///
29    /// # Arguments
30    ///
31    /// * `tensor` - The tensor.
32    ///
33    /// # Returns
34    ///
35    /// The data structure with the tensor's data.
36    fn bool_into_data(tensor: BoolTensor<B>) -> impl Future<Output = TensorData> + 'static + Send;
37
38    /// Creates a tensor from the data structure.
39    ///
40    /// # Arguments
41    ///
42    /// * `data` - The data structure.
43    /// * `device` - The device to create the tensor on.
44    ///
45    /// # Returns
46    ///
47    /// The tensor with the data.
48    fn bool_from_data(data: TensorData, device: &Device<B>) -> BoolTensor<B>;
49
50    /// Converts bool tensor to int tensor.
51    ///
52    /// # Arguments
53    ///
54    /// * `tensor` - The tensor.
55    ///
56    /// # Returns
57    ///
58    /// The int tensor with the same data as the bool tensor.
59    fn bool_into_int(tensor: BoolTensor<B>) -> IntTensor<B>;
60
61    /// Converts bool tensor to float tensor.
62    ///
63    /// # Arguments
64    ///
65    /// * `tensor` - The tensor.
66    ///
67    /// # Returns
68    ///
69    /// The float tensor with the same data as the bool tensor.
70    fn bool_into_float(tensor: BoolTensor<B>) -> FloatTensor<B>;
71
72    /// Gets the device of the tensor.
73    ///
74    /// # Arguments
75    ///
76    /// * `tensor` - The tensor.
77    ///
78    /// # Returns
79    ///
80    /// The device of the tensor.
81    fn bool_device(tensor: &BoolTensor<B>) -> Device<B>;
82
83    /// Moves the tensor to the device.
84    fn bool_to_device(tensor: BoolTensor<B>, device: &Device<B>) -> BoolTensor<B>;
85
86    /// Reshapes the tensor.
87    ///
88    /// # Arguments
89    ///
90    /// * `tensor` - The tensor.
91    /// * `shape` - The new shape.
92    ///
93    /// # Returns
94    ///
95    /// The tensor with the new shape.
96    fn bool_reshape(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B>;
97
98    /// Gets the values from the tensor for the given ranges.
99    ///
100    /// # Arguments
101    ///
102    /// * `tensor` - The tensor.
103    /// * `ranges` - The ranges to get the values from.
104    ///
105    /// # Returns
106    ///
107    /// The tensor with the values for the given ranges.
108    fn bool_slice(tensor: BoolTensor<B>, ranges: &[Range<usize>]) -> BoolTensor<B>;
109
110    /// Sets the values in the tensor for the given ranges.
111    ///
112    /// # Arguments
113    ///
114    /// * `tensor` - The tensor.
115    /// * `ranges` - The ranges to set the values for.
116    /// * `value` - The values to set.
117    ///
118    /// # Returns
119    ///
120    /// The tensor with the values set for the given ranges.
121    fn bool_slice_assign(
122        tensor: BoolTensor<B>,
123        ranges: &[Range<usize>],
124        value: BoolTensor<B>,
125    ) -> BoolTensor<B>;
126
127    /// Repeats one dimension of the tensor a given number of times along that dimension.
128    ///
129    /// # Arguments
130    ///
131    /// * `tensor` - The tensor.
132    /// * `dim` - The dimension to repeat.
133    /// * `times` - The number of times to repeat the dimension.
134    ///
135    /// # Returns
136    ///
137    /// The tensor with the dimension repeated.
138    fn bool_repeat_dim(tensor: BoolTensor<B>, dim: usize, times: usize) -> BoolTensor<B> {
139        repeat_with_slice_assign::<B, Bool>(tensor, dim, times)
140    }
141
142    /// Concatenates the tensors along the given dimension.
143    ///
144    /// # Arguments
145    ///
146    /// * `tensors` - The tensors to concatenate.
147    /// * `dim` - The dimension to concatenate along.
148    ///
149    /// # Returns
150    ///
151    /// The tensor with the tensors concatenated along the given dimension.
152    fn bool_cat(tensors: Vec<BoolTensor<B>>, dim: usize) -> BoolTensor<B> {
153        cat_with_slice_assign::<B, Bool>(tensors, dim)
154    }
155
156    /// Equates the two tensors.
157    ///
158    /// # Arguments
159    ///
160    /// * `lhs` - The left hand side tensor.
161    /// * `rhs` - The right hand side tensor.
162    ///
163    /// # Returns
164    ///
165    /// The tensor with the result of the equate.
166    fn bool_equal(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;
167
168    /// Element-wise non-equality comparison.
169    ///
170    /// # Arguments
171    ///
172    /// * `lhs` - The left hand side tensor.
173    /// * `rhs` - The right hand side tensor.
174    ///
175    /// # Returns
176    ///
177    /// The tensor with the result of the comparison.
178    fn bool_not_equal(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
179        let equal_tensor = B::bool_equal(lhs, rhs);
180        B::bool_not(equal_tensor)
181    }
182
183    /// Inverses boolean values.
184    ///
185    /// # Arguments
186    ///
187    /// * `tensor` - The tensor.
188    ///
189    /// # Returns
190    ///
191    /// The tensor with the result of the negation.
192    fn bool_not(tensor: BoolTensor<B>) -> BoolTensor<B>;
193
194    /// Executes the logical and (`&&`) operation on two boolean tensors.
195    ///
196    /// # Arguments
197    ///
198    /// * `lhs` - The left hand side tensor.
199    /// * `rhs` - The right hand side tensor.
200    ///
201    /// # Returns
202    ///
203    /// The tensor with the result of the logical and.
204    fn bool_and(tensor: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;
205
206    /// Executes the logical or (`||`) operation on two boolean tensors.
207    ///
208    /// # Arguments
209    ///
210    /// * `lhs` - The left hand side tensor.
211    /// * `rhs` - The right hand side tensor.
212    ///
213    /// # Returns
214    ///
215    /// The tensor with the result of the logical or.
216    fn bool_or(tensor: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;
217
218    /// Transposes a bool tensor.
219    ///
220    /// # Arguments
221    ///
222    /// * `tensor` - The tensor to transpose.
223    ///
224    /// # Returns
225    ///
226    /// The transposed tensor.
227    fn bool_transpose(tensor: BoolTensor<B>) -> BoolTensor<B> {
228        let ndims = tensor.shape().num_dims();
229        Self::bool_swap_dims(tensor, ndims - 2, ndims - 1)
230    }
231
232    /// Swaps two dimensions of a bool tensor.
233    ///
234    /// # Arguments
235    ///
236    /// * `tensor` - The tensor to swap the dimensions of.
237    /// * `dim1` - The first dimension to swap.
238    /// * `dim2` - The second dimension to swap.
239    ///
240    /// # Returns
241    ///
242    /// The tensor with the dimensions swapped.
243    fn bool_swap_dims(tensor: BoolTensor<B>, dim1: usize, dim2: usize) -> BoolTensor<B>;
244
245    /// Permutes the dimensions of a tensor.
246    ///
247    /// # Arguments
248    ///
249    /// * `tensor` - The tensor to permute the dimensions of.
250    /// * `axes` - The new order of the dimensions.
251    /// # Returns
252    ///
253    /// The tensor with the dimensions permuted.
254    fn bool_permute(tensor: BoolTensor<B>, axes: &[usize]) -> BoolTensor<B>;
255
256    /// Reverse the order of elements in a tensor along the given axes.
257    ///
258    /// # Arguments
259    ///
260    /// * `tensor` - The tensor to reverse.
261    /// * `axes` - The axes to reverse.
262    ///
263    /// The tensor with the elements reversed.
264    fn bool_flip(tensor: BoolTensor<B>, axes: &[usize]) -> BoolTensor<B>;
265
266    /// Returns a new tensor with the given dimension narrowed to the given range.
267    ///
268    /// # Arguments
269    ///
270    /// * `dim` - The dimension along which the tensor will be narrowed.
271    /// * `start` - The starting point of the given range.
272    /// * `length` - The ending point of the given range.
273    /// # Panics
274    ///
275    /// - If the dimension is greater than the number of dimensions of the tensor.
276    /// - If the given range exceeds the number of elements on the given dimension.
277    ///
278    /// # Returns
279    ///
280    /// A new tensor with the given dimension narrowed to the given range.
281    fn bool_narrow(
282        tensor: BoolTensor<B>,
283        dim: usize,
284        start: usize,
285        length: usize,
286    ) -> BoolTensor<B> {
287        narrow::<B, Bool>(tensor, dim, start, length)
288    }
289
290    /// Split the tensor along the given dimension into chunks.
291    ///
292    /// # Arguments
293    ///
294    /// * `tensor` - The tensor.
295    /// * `chunks` - The number of chunks to be produced.
296    /// * `times` - The dimension along which the tensor will be split.
297    ///
298    /// # Returns
299    ///
300    /// A vector of tensors.
301    fn bool_chunk(tensor: BoolTensor<B>, chunks: usize, dim: usize) -> Vec<BoolTensor<B>> {
302        chunk::<B, Bool>(tensor, chunks, dim)
303    }
304
305    /// Split the tensor along the given dimension into chunks of `split_size`.
306    ///
307    /// # Arguments
308    ///
309    /// * `tensor` - The tensor.
310    /// * `split_size` - The size of a single chunk.
311    /// * `times` - The dimension along which the tensor will be split.
312    ///
313    /// # Returns
314    ///
315    /// A vector of tensors.
316    fn bool_split(tensor: BoolTensor<B>, split_size: usize, dim: usize) -> Vec<BoolTensor<B>> {
317        split::<B, Bool>(tensor, split_size, dim)
318    }
319
320    /// Split the tensor along the given dimension into chunks with sizes in
321    /// `dim` according to `split_sizes`.
322    ///
323    /// # Arguments
324    ///
325    /// * `tensor` - The tensor.
326    /// * `split_sizes` - Vector of sizes for each chunk.
327    /// * `times` - The dimension along which the tensor will be split.
328    ///
329    /// # Returns
330    ///
331    /// A vector of tensors.
332    fn bool_split_with_sizes(
333        tensor: BoolTensor<B>,
334        split_sizes: Vec<usize>,
335        dim: usize,
336    ) -> Vec<BoolTensor<B>> {
337        split_with_sizes::<B, Bool>(tensor, split_sizes, dim)
338    }
339
340    /// Tests if any element in the boolean `tensor` evaluates to True.
341    ///
342    /// # Arguments
343    ///
344    /// * `tensor` - The tensor to test.
345    ///
346    /// # Returns
347    ///
348    /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
349    fn bool_any(tensor: BoolTensor<B>) -> BoolTensor<B> {
350        let sum = B::int_sum(B::bool_into_int(tensor));
351        B::int_greater_elem(sum, 0.elem())
352    }
353
354    /// Tests if any element in the boolean `tensor` evaluates to True along a given dimension `dim`.
355    ///
356    /// # Arguments
357    ///
358    /// * `tensor` - The tensor to test.
359    /// * `dim` - The axis along which to test.
360    ///
361    /// # Returns
362    ///
363    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
364    /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input
365    /// evaluates to True, False otherwise.
366    fn bool_any_dim(tensor: BoolTensor<B>, dim: usize) -> BoolTensor<B> {
367        let sum = B::int_sum_dim(B::bool_into_int(tensor), dim);
368        B::int_greater_elem(sum, 0.elem())
369    }
370
371    /// Tests if all elements in the boolean `tensor` evaluate to True.
372    ///
373    /// # Arguments
374    ///
375    /// * `tensor` - The tensor to test.
376    ///
377    /// # Returns
378    ///
379    /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
380    /// evaluate to True, False otherwise.
381    fn bool_all(tensor: BoolTensor<B>) -> BoolTensor<B> {
382        let num_elems = tensor.shape().num_elements();
383        let sum = B::int_sum(B::bool_into_int(tensor));
384        B::int_equal_elem(sum, (num_elems as i32).elem())
385    }
386
387    /// Tests if all elements in the boolean `tensor` evaluate to True along a given dimension `dim`.
388    ///
389    /// # Arguments
390    ///
391    /// * `tensor` - The tensor to test.
392    /// * `dim` - The axis along which to test.
393    ///
394    /// # Returns
395    ///
396    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
397    /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
398    /// evaluates to True, False otherwise.
399    fn bool_all_dim(tensor: BoolTensor<B>, dim: usize) -> BoolTensor<B> {
400        let num_elems = tensor.shape().dims[dim];
401        let sum = B::int_sum_dim(B::bool_into_int(tensor), dim);
402        B::int_equal_elem(sum, (num_elems as i32).elem())
403    }
404
405    /// Compute the indices of the elements that are non-zero, grouped by element.
406    ///
407    /// # Arguments
408    ///
409    /// * `tensor` - The input tensor.
410    ///
411    /// # Returns
412    ///
413    /// A 2D tensor containing the indices of all non-zero elements of the given tensor.
414    /// Each row contains the indices of a non-zero element.
415    fn bool_argwhere(tensor: BoolTensor<B>) -> impl Future<Output = IntTensor<B>> + 'static + Send {
416        async {
417            // Size of each output tensor is variable (= number of nonzero elements in the tensor).
418            // Reading the data to count the number of truth values might cause sync but is required.
419            let device = B::bool_device(&tensor);
420            let data = B::bool_into_data(tensor).await;
421            argwhere_data::<B>(data, &device)
422        }
423    }
424
425    /// Compute the indices of the elements that are non-zero.
426    ///
427    /// # Arguments
428    ///
429    /// * `tensor` - The input tensor.
430    ///
431    /// # Returns
432    ///
433    /// A vector of tensors, one for each dimension of the given tensor, containing the indices of
434    /// the non-zero elements in that dimension. If all elements are zero, the vector is empty.
435    fn bool_nonzero(
436        tensor: BoolTensor<B>,
437    ) -> impl Future<Output = Vec<IntTensor<B>>> + 'static + Send {
438        async {
439            let indices = B::bool_argwhere(tensor).await;
440
441            if indices.shape().num_elements() == 0 {
442                // Return empty vec when all elements are zero
443                return vec![];
444            }
445
446            let dims = indices.shape().dims;
447            B::int_chunk(indices, dims[1], 1)
448                .into_iter()
449                .map(|t| B::int_reshape(t, Shape::new([dims[0]])))
450                .collect()
451        }
452    }
453
454    /// Broadcasts the bool `tensor` to the given `shape`.
455    fn bool_expand(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B>;
456}