burn_tensor/tensor/ops/
bool_tensor.rs

1use super::{
2    cat::cat_with_slice_assign, repeat_dim::repeat_with_slice_assign, BoolTensor, Device,
3    FloatTensor, IntTensor,
4};
5use crate::{
6    argwhere_data, backend::Backend, chunk, narrow, split, split_with_sizes, tensor::Shape, Bool,
7    ElementConversion, TensorData, TensorMetadata,
8};
9use alloc::{vec, vec::Vec};
10use core::{future::Future, ops::Range};
11
12/// Bool Tensor API for basic operations, see [tensor](crate::Tensor)
13/// for documentation on each function.
14pub trait BoolTensorOps<B: Backend> {
15    /// Creates a new bool tensor.
16    ///
17    /// # Arguments
18    ///
19    /// * `shape` - The shape of the tensor.
20    /// * `device` - The device to create the tensor on.
21    ///
22    /// # Returns
23    ///
24    /// The boolean tensor with the given shape.
25    fn bool_empty(shape: Shape, device: &Device<B>) -> BoolTensor<B>;
26
27    /// Converts the tensor to a data structure.
28    ///
29    /// # Arguments
30    ///
31    /// * `tensor` - The tensor.
32    ///
33    /// # Returns
34    ///
35    /// The data structure with the tensor's data.
36    fn bool_into_data(tensor: BoolTensor<B>) -> impl Future<Output = TensorData> + 'static + Send;
37
38    /// Creates a tensor from the data structure.
39    ///
40    /// # Arguments
41    ///
42    /// * `data` - The data structure.
43    /// * `device` - The device to create the tensor on.
44    ///
45    /// # Returns
46    ///
47    /// The tensor with the data.
48    fn bool_from_data(data: TensorData, device: &Device<B>) -> BoolTensor<B>;
49
50    /// Converts bool tensor to int tensor.
51    ///
52    /// # Arguments
53    ///
54    /// * `tensor` - The tensor.
55    ///
56    /// # Returns
57    ///
58    /// The int tensor with the same data as the bool tensor.
59    fn bool_into_int(tensor: BoolTensor<B>) -> IntTensor<B>;
60
61    /// Converts bool tensor to float tensor.
62    ///
63    /// # Arguments
64    ///
65    /// * `tensor` - The tensor.
66    ///
67    /// # Returns
68    ///
69    /// The float tensor with the same data as the bool tensor.
70    fn bool_into_float(tensor: BoolTensor<B>) -> FloatTensor<B>;
71
72    /// Gets the device of the tensor.
73    ///
74    /// # Arguments
75    ///
76    /// * `tensor` - The tensor.
77    ///
78    /// # Returns
79    ///
80    /// The device of the tensor.
81    fn bool_device(tensor: &BoolTensor<B>) -> Device<B>;
82
83    /// Moves the tensor to the device.
84    fn bool_to_device(tensor: BoolTensor<B>, device: &Device<B>) -> BoolTensor<B>;
85
86    /// Reshapes the tensor.
87    ///
88    /// # Arguments
89    ///
90    /// * `tensor` - The tensor.
91    /// * `shape` - The new shape.
92    ///
93    /// # Returns
94    ///
95    /// The tensor with the new shape.
96    fn bool_reshape(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B>;
97
98    /// Gets the values from the tensor for the given ranges.
99    ///
100    /// # Arguments
101    ///
102    /// * `tensor` - The tensor.
103    /// * `ranges` - The ranges to get the values from.
104    ///
105    /// # Returns
106    ///
107    /// The tensor with the values for the given ranges.
108    fn bool_slice(tensor: BoolTensor<B>, ranges: &[Range<usize>]) -> BoolTensor<B>;
109
110    /// Sets the values in the tensor for the given ranges.
111    ///
112    /// # Arguments
113    ///
114    /// * `tensor` - The tensor.
115    /// * `ranges` - The ranges to set the values for.
116    /// * `value` - The values to set.
117    ///
118    /// # Returns
119    ///
120    /// The tensor with the values set for the given ranges.
121    fn bool_slice_assign(
122        tensor: BoolTensor<B>,
123        ranges: &[Range<usize>],
124        value: BoolTensor<B>,
125    ) -> BoolTensor<B>;
126
127    /// Repeats one dimension of the tensor a given number of times along that dimension.
128    ///
129    /// # Arguments
130    ///
131    /// * `tensor` - The tensor.
132    /// * `dim` - The dimension to repeat.
133    /// * `times` - The number of times to repeat the dimension.
134    ///
135    /// # Returns
136    ///
137    /// The tensor with the dimension repeated.
138    fn bool_repeat_dim(tensor: BoolTensor<B>, dim: usize, times: usize) -> BoolTensor<B> {
139        repeat_with_slice_assign::<B, Bool>(tensor, dim, times)
140    }
141
142    /// Concatenates the tensors along the given dimension.
143    ///
144    /// # Arguments
145    ///
146    /// * `tensors` - The tensors to concatenate.
147    /// * `dim` - The dimension to concatenate along.
148    ///
149    /// # Returns
150    ///
151    /// The tensor with the tensors concatenated along the given dimension.
152    fn bool_cat(tensors: Vec<BoolTensor<B>>, dim: usize) -> BoolTensor<B> {
153        cat_with_slice_assign::<B, Bool>(tensors, dim)
154    }
155
156    /// Equates the two tensors.
157    ///
158    /// # Arguments
159    ///
160    /// * `lhs` - The left hand side tensor.
161    /// * `rhs` - The right hand side tensor.
162    ///
163    /// # Returns
164    ///
165    /// The tensor with the result of the equate.
166    fn bool_equal(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;
167
168    /// Element-wise non-equality comparison.
169    ///
170    /// # Arguments
171    ///
172    /// * `lhs` - The left hand side tensor.
173    /// * `rhs` - The right hand side tensor.
174    ///
175    /// # Returns
176    ///
177    /// The tensor with the result of the comparison.
178    fn bool_not_equal(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
179        let equal_tensor = B::bool_equal(lhs, rhs);
180        B::bool_not(equal_tensor)
181    }
182
183    /// Inverses boolean values.
184    ///
185    /// # Arguments
186    ///
187    /// * `tensor` - The tensor.
188    ///
189    /// # Returns
190    ///
191    /// The tensor with the result of the negation.
192    fn bool_not(tensor: BoolTensor<B>) -> BoolTensor<B>;
193
194    /// Transposes a bool tensor.
195    ///
196    /// # Arguments
197    ///
198    /// * `tensor` - The tensor to transpose.
199    ///
200    /// # Returns
201    ///
202    /// The transposed tensor.
203    fn bool_transpose(tensor: BoolTensor<B>) -> BoolTensor<B> {
204        let ndims = tensor.shape().num_dims();
205        Self::bool_swap_dims(tensor, ndims - 2, ndims - 1)
206    }
207
208    /// Swaps two dimensions of a bool tensor.
209    ///
210    /// # Arguments
211    ///
212    /// * `tensor` - The tensor to swap the dimensions of.
213    /// * `dim1` - The first dimension to swap.
214    /// * `dim2` - The second dimension to swap.
215    ///
216    /// # Returns
217    ///
218    /// The tensor with the dimensions swapped.
219    fn bool_swap_dims(tensor: BoolTensor<B>, dim1: usize, dim2: usize) -> BoolTensor<B>;
220
221    /// Permutes the dimensions of a tensor.
222    ///
223    /// # Arguments
224    ///
225    /// * `tensor` - The tensor to permute the dimensions of.
226    /// * `axes` - The new order of the dimensions.
227    /// # Returns
228    ///
229    /// The tensor with the dimensions permuted.
230    fn bool_permute(tensor: BoolTensor<B>, axes: &[usize]) -> BoolTensor<B>;
231
232    /// Reverse the order of elements in a tensor along the given axes.
233    ///
234    /// # Arguments
235    ///
236    /// * `tensor` - The tensor to reverse.
237    /// * `axes` - The axes to reverse.
238    ///
239    /// The tensor with the elements reversed.
240    fn bool_flip(tensor: BoolTensor<B>, axes: &[usize]) -> BoolTensor<B>;
241
242    /// Returns a new tensor with the given dimension narrowed to the given range.
243    ///
244    /// # Arguments
245    ///
246    /// * `dim` - The dimension along which the tensor will be narrowed.
247    /// * `start` - The starting point of the given range.
248    /// * `length` - The ending point of the given range.
249    /// # Panics
250    ///
251    /// - If the dimension is greater than the number of dimensions of the tensor.
252    /// - If the given range exceeds the number of elements on the given dimension.
253    ///
254    /// # Returns
255    ///
256    /// A new tensor with the given dimension narrowed to the given range.
257    fn bool_narrow(
258        tensor: BoolTensor<B>,
259        dim: usize,
260        start: usize,
261        length: usize,
262    ) -> BoolTensor<B> {
263        narrow::<B, Bool>(tensor, dim, start, length)
264    }
265
266    /// Split the tensor along the given dimension into chunks.
267    ///
268    /// # Arguments
269    ///
270    /// * `tensor` - The tensor.
271    /// * `chunks` - The number of chunks to be produced.
272    /// * `times` - The dimension along which the tensor will be split.
273    ///
274    /// # Returns
275    ///
276    /// A vector of tensors.
277    fn bool_chunk(tensor: BoolTensor<B>, chunks: usize, dim: usize) -> Vec<BoolTensor<B>> {
278        chunk::<B, Bool>(tensor, chunks, dim)
279    }
280
281    /// Split the tensor along the given dimension into chunks of `split_size`.
282    ///
283    /// # Arguments
284    ///
285    /// * `tensor` - The tensor.
286    /// * `split_size` - The size of a single chunk.
287    /// * `times` - The dimension along which the tensor will be split.
288    ///
289    /// # Returns
290    ///
291    /// A vector of tensors.
292    fn bool_split(tensor: BoolTensor<B>, split_size: usize, dim: usize) -> Vec<BoolTensor<B>> {
293        split::<B, Bool>(tensor, split_size, dim)
294    }
295
296    /// Split the tensor along the given dimension into chunks with sizes in
297    /// `dim` according to `split_sizes`.
298    ///
299    /// # Arguments
300    ///
301    /// * `tensor` - The tensor.
302    /// * `split_sizes` - Vector of sizes for each chunk.
303    /// * `times` - The dimension along which the tensor will be split.
304    ///
305    /// # Returns
306    ///
307    /// A vector of tensors.
308    fn bool_split_with_sizes(
309        tensor: BoolTensor<B>,
310        split_sizes: Vec<usize>,
311        dim: usize,
312    ) -> Vec<BoolTensor<B>> {
313        split_with_sizes::<B, Bool>(tensor, split_sizes, dim)
314    }
315
316    /// Tests if any element in the boolean `tensor` evaluates to True.
317    ///
318    /// # Arguments
319    ///
320    /// * `tensor` - The tensor to test.
321    ///
322    /// # Returns
323    ///
324    /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
325    fn bool_any(tensor: BoolTensor<B>) -> BoolTensor<B> {
326        let sum = B::int_sum(B::bool_into_int(tensor));
327        B::int_greater_elem(sum, 0.elem())
328    }
329
330    /// Tests if any element in the boolean `tensor` evaluates to True along a given dimension `dim`.
331    ///
332    /// # Arguments
333    ///
334    /// * `tensor` - The tensor to test.
335    /// * `dim` - The axis along which to test.
336    ///
337    /// # Returns
338    ///
339    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
340    /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input
341    /// evaluates to True, False otherwise.
342    fn bool_any_dim(tensor: BoolTensor<B>, dim: usize) -> BoolTensor<B> {
343        let sum = B::int_sum_dim(B::bool_into_int(tensor), dim);
344        B::int_greater_elem(sum, 0.elem())
345    }
346
347    /// Tests if all elements in the boolean `tensor` evaluate to True.
348    ///
349    /// # Arguments
350    ///
351    /// * `tensor` - The tensor to test.
352    ///
353    /// # Returns
354    ///
355    /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
356    /// evaluate to True, False otherwise.
357    fn bool_all(tensor: BoolTensor<B>) -> BoolTensor<B> {
358        let num_elems = tensor.shape().num_elements();
359        let sum = B::int_sum(B::bool_into_int(tensor));
360        B::int_equal_elem(sum, (num_elems as i32).elem())
361    }
362
363    /// Tests if all elements in the boolean `tensor` evaluate to True along a given dimension `dim`.
364    ///
365    /// # Arguments
366    ///
367    /// * `tensor` - The tensor to test.
368    /// * `dim` - The axis along which to test.
369    ///
370    /// # Returns
371    ///
372    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
373    /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
374    /// evaluates to True, False otherwise.
375    fn bool_all_dim(tensor: BoolTensor<B>, dim: usize) -> BoolTensor<B> {
376        let num_elems = tensor.shape().dims[dim];
377        let sum = B::int_sum_dim(B::bool_into_int(tensor), dim);
378        B::int_equal_elem(sum, (num_elems as i32).elem())
379    }
380
381    /// Compute the indices of the elements that are non-zero, grouped by element.
382    ///
383    /// # Arguments
384    ///
385    /// * `tensor` - The input tensor.
386    ///
387    /// # Returns
388    ///
389    /// A 2D tensor containing the indices of all non-zero elements of the given tensor.
390    /// Each row contains the indices of a non-zero element.
391    fn bool_argwhere(tensor: BoolTensor<B>) -> impl Future<Output = IntTensor<B>> + 'static + Send {
392        async {
393            // Size of each output tensor is variable (= number of nonzero elements in the tensor).
394            // Reading the data to count the number of truth values might cause sync but is required.
395            let device = B::bool_device(&tensor);
396            let data = B::bool_into_data(tensor).await;
397            argwhere_data::<B>(data, &device)
398        }
399    }
400
401    /// Compute the indices of the elements that are non-zero.
402    ///
403    /// # Arguments
404    ///
405    /// * `tensor` - The input tensor.
406    ///
407    /// # Returns
408    ///
409    /// A vector of tensors, one for each dimension of the given tensor, containing the indices of
410    /// the non-zero elements in that dimension. If all elements are zero, the vector is empty.
411    fn bool_nonzero(
412        tensor: BoolTensor<B>,
413    ) -> impl Future<Output = Vec<IntTensor<B>>> + 'static + Send {
414        async {
415            let indices = B::bool_argwhere(tensor).await;
416
417            if indices.shape().num_elements() == 0 {
418                // Return empty vec when all elements are zero
419                return vec![];
420            }
421
422            let dims = indices.shape().dims;
423            B::int_chunk(indices, dims[1], 1)
424                .into_iter()
425                .map(|t| B::int_reshape(t, Shape::new([dims[0]])))
426                .collect()
427        }
428    }
429
430    /// Broadcasts the bool `tensor` to the given `shape`.
431    fn bool_expand(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B>;
432}