burn_tensor/tensor/ops/bool_tensor.rs
1use super::{
2 BoolTensor, Device, FloatTensor, IntTensor, cat::cat_with_slice_assign,
3 repeat_dim::repeat_with_slice_assign,
4};
5use crate::{
6 Bool, ElementConversion, TensorData, TensorMetadata, argwhere_data, backend::Backend, chunk,
7 narrow, split, split_with_sizes, tensor::Shape,
8};
9use alloc::{vec, vec::Vec};
10use core::{future::Future, ops::Range};
11
12/// Bool Tensor API for basic operations, see [tensor](crate::Tensor)
13/// for documentation on each function.
14pub trait BoolTensorOps<B: Backend> {
15 /// Creates a new bool tensor.
16 ///
17 /// # Arguments
18 ///
19 /// * `shape` - The shape of the tensor.
20 /// * `device` - The device to create the tensor on.
21 ///
22 /// # Returns
23 ///
24 /// The boolean tensor with the given shape.
25 fn bool_empty(shape: Shape, device: &Device<B>) -> BoolTensor<B>;
26
27 /// Converts the tensor to a data structure.
28 ///
29 /// # Arguments
30 ///
31 /// * `tensor` - The tensor.
32 ///
33 /// # Returns
34 ///
35 /// The data structure with the tensor's data.
36 fn bool_into_data(tensor: BoolTensor<B>) -> impl Future<Output = TensorData> + 'static + Send;
37
38 /// Creates a tensor from the data structure.
39 ///
40 /// # Arguments
41 ///
42 /// * `data` - The data structure.
43 /// * `device` - The device to create the tensor on.
44 ///
45 /// # Returns
46 ///
47 /// The tensor with the data.
48 fn bool_from_data(data: TensorData, device: &Device<B>) -> BoolTensor<B>;
49
50 /// Converts bool tensor to int tensor.
51 ///
52 /// # Arguments
53 ///
54 /// * `tensor` - The tensor.
55 ///
56 /// # Returns
57 ///
58 /// The int tensor with the same data as the bool tensor.
59 fn bool_into_int(tensor: BoolTensor<B>) -> IntTensor<B>;
60
61 /// Converts bool tensor to float tensor.
62 ///
63 /// # Arguments
64 ///
65 /// * `tensor` - The tensor.
66 ///
67 /// # Returns
68 ///
69 /// The float tensor with the same data as the bool tensor.
70 fn bool_into_float(tensor: BoolTensor<B>) -> FloatTensor<B>;
71
72 /// Gets the device of the tensor.
73 ///
74 /// # Arguments
75 ///
76 /// * `tensor` - The tensor.
77 ///
78 /// # Returns
79 ///
80 /// The device of the tensor.
81 fn bool_device(tensor: &BoolTensor<B>) -> Device<B>;
82
83 /// Moves the tensor to the device.
84 fn bool_to_device(tensor: BoolTensor<B>, device: &Device<B>) -> BoolTensor<B>;
85
86 /// Reshapes the tensor.
87 ///
88 /// # Arguments
89 ///
90 /// * `tensor` - The tensor.
91 /// * `shape` - The new shape.
92 ///
93 /// # Returns
94 ///
95 /// The tensor with the new shape.
96 fn bool_reshape(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B>;
97
98 /// Gets the values from the tensor for the given ranges.
99 ///
100 /// # Arguments
101 ///
102 /// * `tensor` - The tensor.
103 /// * `ranges` - The ranges to get the values from.
104 ///
105 /// # Returns
106 ///
107 /// The tensor with the values for the given ranges.
108 fn bool_slice(tensor: BoolTensor<B>, ranges: &[Range<usize>]) -> BoolTensor<B>;
109
110 /// Sets the values in the tensor for the given ranges.
111 ///
112 /// # Arguments
113 ///
114 /// * `tensor` - The tensor.
115 /// * `ranges` - The ranges to set the values for.
116 /// * `value` - The values to set.
117 ///
118 /// # Returns
119 ///
120 /// The tensor with the values set for the given ranges.
121 fn bool_slice_assign(
122 tensor: BoolTensor<B>,
123 ranges: &[Range<usize>],
124 value: BoolTensor<B>,
125 ) -> BoolTensor<B>;
126
127 /// Repeats one dimension of the tensor a given number of times along that dimension.
128 ///
129 /// # Arguments
130 ///
131 /// * `tensor` - The tensor.
132 /// * `dim` - The dimension to repeat.
133 /// * `times` - The number of times to repeat the dimension.
134 ///
135 /// # Returns
136 ///
137 /// The tensor with the dimension repeated.
138 fn bool_repeat_dim(tensor: BoolTensor<B>, dim: usize, times: usize) -> BoolTensor<B> {
139 repeat_with_slice_assign::<B, Bool>(tensor, dim, times)
140 }
141
142 /// Concatenates the tensors along the given dimension.
143 ///
144 /// # Arguments
145 ///
146 /// * `tensors` - The tensors to concatenate.
147 /// * `dim` - The dimension to concatenate along.
148 ///
149 /// # Returns
150 ///
151 /// The tensor with the tensors concatenated along the given dimension.
152 fn bool_cat(tensors: Vec<BoolTensor<B>>, dim: usize) -> BoolTensor<B> {
153 cat_with_slice_assign::<B, Bool>(tensors, dim)
154 }
155
156 /// Equates the two tensors.
157 ///
158 /// # Arguments
159 ///
160 /// * `lhs` - The left hand side tensor.
161 /// * `rhs` - The right hand side tensor.
162 ///
163 /// # Returns
164 ///
165 /// The tensor with the result of the equate.
166 fn bool_equal(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;
167
168 /// Element-wise non-equality comparison.
169 ///
170 /// # Arguments
171 ///
172 /// * `lhs` - The left hand side tensor.
173 /// * `rhs` - The right hand side tensor.
174 ///
175 /// # Returns
176 ///
177 /// The tensor with the result of the comparison.
178 fn bool_not_equal(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
179 let equal_tensor = B::bool_equal(lhs, rhs);
180 B::bool_not(equal_tensor)
181 }
182
183 /// Inverses boolean values.
184 ///
185 /// # Arguments
186 ///
187 /// * `tensor` - The tensor.
188 ///
189 /// # Returns
190 ///
191 /// The tensor with the result of the negation.
192 fn bool_not(tensor: BoolTensor<B>) -> BoolTensor<B>;
193
194 /// Executes the logical and (`&&`) operation on two boolean tensors.
195 ///
196 /// # Arguments
197 ///
198 /// * `lhs` - The left hand side tensor.
199 /// * `rhs` - The right hand side tensor.
200 ///
201 /// # Returns
202 ///
203 /// The tensor with the result of the logical and.
204 fn bool_and(tensor: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;
205
206 /// Executes the logical or (`||`) operation on two boolean tensors.
207 ///
208 /// # Arguments
209 ///
210 /// * `lhs` - The left hand side tensor.
211 /// * `rhs` - The right hand side tensor.
212 ///
213 /// # Returns
214 ///
215 /// The tensor with the result of the logical or.
216 fn bool_or(tensor: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;
217
218 /// Transposes a bool tensor.
219 ///
220 /// # Arguments
221 ///
222 /// * `tensor` - The tensor to transpose.
223 ///
224 /// # Returns
225 ///
226 /// The transposed tensor.
227 fn bool_transpose(tensor: BoolTensor<B>) -> BoolTensor<B> {
228 let ndims = tensor.shape().num_dims();
229 Self::bool_swap_dims(tensor, ndims - 2, ndims - 1)
230 }
231
232 /// Swaps two dimensions of a bool tensor.
233 ///
234 /// # Arguments
235 ///
236 /// * `tensor` - The tensor to swap the dimensions of.
237 /// * `dim1` - The first dimension to swap.
238 /// * `dim2` - The second dimension to swap.
239 ///
240 /// # Returns
241 ///
242 /// The tensor with the dimensions swapped.
243 fn bool_swap_dims(tensor: BoolTensor<B>, dim1: usize, dim2: usize) -> BoolTensor<B>;
244
245 /// Permutes the dimensions of a tensor.
246 ///
247 /// # Arguments
248 ///
249 /// * `tensor` - The tensor to permute the dimensions of.
250 /// * `axes` - The new order of the dimensions.
251 /// # Returns
252 ///
253 /// The tensor with the dimensions permuted.
254 fn bool_permute(tensor: BoolTensor<B>, axes: &[usize]) -> BoolTensor<B>;
255
256 /// Reverse the order of elements in a tensor along the given axes.
257 ///
258 /// # Arguments
259 ///
260 /// * `tensor` - The tensor to reverse.
261 /// * `axes` - The axes to reverse.
262 ///
263 /// The tensor with the elements reversed.
264 fn bool_flip(tensor: BoolTensor<B>, axes: &[usize]) -> BoolTensor<B>;
265
266 /// Returns a new tensor with the given dimension narrowed to the given range.
267 ///
268 /// # Arguments
269 ///
270 /// * `dim` - The dimension along which the tensor will be narrowed.
271 /// * `start` - The starting point of the given range.
272 /// * `length` - The ending point of the given range.
273 /// # Panics
274 ///
275 /// - If the dimension is greater than the number of dimensions of the tensor.
276 /// - If the given range exceeds the number of elements on the given dimension.
277 ///
278 /// # Returns
279 ///
280 /// A new tensor with the given dimension narrowed to the given range.
281 fn bool_narrow(
282 tensor: BoolTensor<B>,
283 dim: usize,
284 start: usize,
285 length: usize,
286 ) -> BoolTensor<B> {
287 narrow::<B, Bool>(tensor, dim, start, length)
288 }
289
290 /// Split the tensor along the given dimension into chunks.
291 ///
292 /// # Arguments
293 ///
294 /// * `tensor` - The tensor.
295 /// * `chunks` - The number of chunks to be produced.
296 /// * `times` - The dimension along which the tensor will be split.
297 ///
298 /// # Returns
299 ///
300 /// A vector of tensors.
301 fn bool_chunk(tensor: BoolTensor<B>, chunks: usize, dim: usize) -> Vec<BoolTensor<B>> {
302 chunk::<B, Bool>(tensor, chunks, dim)
303 }
304
305 /// Split the tensor along the given dimension into chunks of `split_size`.
306 ///
307 /// # Arguments
308 ///
309 /// * `tensor` - The tensor.
310 /// * `split_size` - The size of a single chunk.
311 /// * `times` - The dimension along which the tensor will be split.
312 ///
313 /// # Returns
314 ///
315 /// A vector of tensors.
316 fn bool_split(tensor: BoolTensor<B>, split_size: usize, dim: usize) -> Vec<BoolTensor<B>> {
317 split::<B, Bool>(tensor, split_size, dim)
318 }
319
320 /// Split the tensor along the given dimension into chunks with sizes in
321 /// `dim` according to `split_sizes`.
322 ///
323 /// # Arguments
324 ///
325 /// * `tensor` - The tensor.
326 /// * `split_sizes` - Vector of sizes for each chunk.
327 /// * `times` - The dimension along which the tensor will be split.
328 ///
329 /// # Returns
330 ///
331 /// A vector of tensors.
332 fn bool_split_with_sizes(
333 tensor: BoolTensor<B>,
334 split_sizes: Vec<usize>,
335 dim: usize,
336 ) -> Vec<BoolTensor<B>> {
337 split_with_sizes::<B, Bool>(tensor, split_sizes, dim)
338 }
339
340 /// Tests if any element in the boolean `tensor` evaluates to True.
341 ///
342 /// # Arguments
343 ///
344 /// * `tensor` - The tensor to test.
345 ///
346 /// # Returns
347 ///
348 /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
349 fn bool_any(tensor: BoolTensor<B>) -> BoolTensor<B> {
350 let sum = B::int_sum(B::bool_into_int(tensor));
351 B::int_greater_elem(sum, 0.elem())
352 }
353
354 /// Tests if any element in the boolean `tensor` evaluates to True along a given dimension `dim`.
355 ///
356 /// # Arguments
357 ///
358 /// * `tensor` - The tensor to test.
359 /// * `dim` - The axis along which to test.
360 ///
361 /// # Returns
362 ///
363 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
364 /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input
365 /// evaluates to True, False otherwise.
366 fn bool_any_dim(tensor: BoolTensor<B>, dim: usize) -> BoolTensor<B> {
367 let sum = B::int_sum_dim(B::bool_into_int(tensor), dim);
368 B::int_greater_elem(sum, 0.elem())
369 }
370
371 /// Tests if all elements in the boolean `tensor` evaluate to True.
372 ///
373 /// # Arguments
374 ///
375 /// * `tensor` - The tensor to test.
376 ///
377 /// # Returns
378 ///
379 /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
380 /// evaluate to True, False otherwise.
381 fn bool_all(tensor: BoolTensor<B>) -> BoolTensor<B> {
382 let num_elems = tensor.shape().num_elements();
383 let sum = B::int_sum(B::bool_into_int(tensor));
384 B::int_equal_elem(sum, (num_elems as i32).elem())
385 }
386
387 /// Tests if all elements in the boolean `tensor` evaluate to True along a given dimension `dim`.
388 ///
389 /// # Arguments
390 ///
391 /// * `tensor` - The tensor to test.
392 /// * `dim` - The axis along which to test.
393 ///
394 /// # Returns
395 ///
396 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
397 /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
398 /// evaluates to True, False otherwise.
399 fn bool_all_dim(tensor: BoolTensor<B>, dim: usize) -> BoolTensor<B> {
400 let num_elems = tensor.shape().dims[dim];
401 let sum = B::int_sum_dim(B::bool_into_int(tensor), dim);
402 B::int_equal_elem(sum, (num_elems as i32).elem())
403 }
404
405 /// Compute the indices of the elements that are non-zero, grouped by element.
406 ///
407 /// # Arguments
408 ///
409 /// * `tensor` - The input tensor.
410 ///
411 /// # Returns
412 ///
413 /// A 2D tensor containing the indices of all non-zero elements of the given tensor.
414 /// Each row contains the indices of a non-zero element.
415 fn bool_argwhere(tensor: BoolTensor<B>) -> impl Future<Output = IntTensor<B>> + 'static + Send {
416 async {
417 // Size of each output tensor is variable (= number of nonzero elements in the tensor).
418 // Reading the data to count the number of truth values might cause sync but is required.
419 let device = B::bool_device(&tensor);
420 let data = B::bool_into_data(tensor).await;
421 argwhere_data::<B>(data, &device)
422 }
423 }
424
425 /// Compute the indices of the elements that are non-zero.
426 ///
427 /// # Arguments
428 ///
429 /// * `tensor` - The input tensor.
430 ///
431 /// # Returns
432 ///
433 /// A vector of tensors, one for each dimension of the given tensor, containing the indices of
434 /// the non-zero elements in that dimension. If all elements are zero, the vector is empty.
435 fn bool_nonzero(
436 tensor: BoolTensor<B>,
437 ) -> impl Future<Output = Vec<IntTensor<B>>> + 'static + Send {
438 async {
439 let indices = B::bool_argwhere(tensor).await;
440
441 if indices.shape().num_elements() == 0 {
442 // Return empty vec when all elements are zero
443 return vec![];
444 }
445
446 let dims = indices.shape().dims;
447 B::int_chunk(indices, dims[1], 1)
448 .into_iter()
449 .map(|t| B::int_reshape(t, Shape::new([dims[0]])))
450 .collect()
451 }
452 }
453
454 /// Broadcasts the bool `tensor` to the given `shape`.
455 fn bool_expand(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B>;
456}