burn_backend/backend/ops/bool_tensor.rs
1use super::{
2 argwhere::argwhere_data, cat::cat_with_slice_assign, repeat_dim::repeat_with_slice_assign,
3};
4use crate::ExecutionError;
5use crate::tensor::{Bool, BoolTensor, Device, FloatTensor, IntTensor};
6use crate::{Backend, TensorData, TensorMetadata, element::ElementConversion};
7use alloc::vec::Vec;
8use burn_std::{Shape, Slice};
9use core::future::Future;
10
11/// Bool Tensor API for basic operations, see
12#[cfg_attr(doc, doc = crate::doc_tensor!())]
13#[cfg_attr(not(doc), doc = "`Tensor`")]
14/// for documentation on each function.
15pub trait BoolTensorOps<B: Backend> {
16 /// Creates a new bool tensor.
17 ///
18 /// # Arguments
19 ///
20 /// * `shape` - The shape of the tensor.
21 /// * `device` - The device to create the tensor on.
22 ///
23 /// # Returns
24 ///
25 /// The boolean tensor with the given shape.
26 fn bool_empty(shape: Shape, device: &Device<B>) -> BoolTensor<B>;
27
28 /// Creates a new bool tensor filled false.
29 ///
30 /// # Arguments
31 ///
32 /// * `shape` - The shape of the tensor.
33 /// * `device` - The device to create the tensor on.
34 ///
35 /// # Returns
36 ///
37 /// The boolean tensor filled with false.
38 fn bool_zeros(shape: Shape, device: &Device<B>) -> BoolTensor<B>;
39
40 /// Creates a new bool tensor filled true.
41 ///
42 /// # Arguments
43 ///
44 /// * `shape` - The shape of the tensor.
45 /// * `device` - The device to create the tensor on.
46 ///
47 /// # Returns
48 ///
49 /// The boolean tensor filled with true.
50 fn bool_ones(shape: Shape, device: &Device<B>) -> BoolTensor<B>;
51
52 /// Converts the tensor to a data structure.
53 ///
54 /// # Arguments
55 ///
56 /// * `tensor` - The tensor.
57 ///
58 /// # Returns
59 ///
60 /// The data structure with the tensor's data.
61 fn bool_into_data(
62 tensor: BoolTensor<B>,
63 ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;
64
65 /// Creates a tensor from the data structure.
66 ///
67 /// # Arguments
68 ///
69 /// * `data` - The data structure.
70 /// * `device` - The device to create the tensor on.
71 ///
72 /// # Returns
73 ///
74 /// The tensor with the data.
75 fn bool_from_data(data: TensorData, device: &Device<B>) -> BoolTensor<B>;
76
77 /// Converts bool tensor to int tensor.
78 ///
79 /// # Arguments
80 ///
81 /// * `tensor` - The tensor.
82 ///
83 /// # Returns
84 ///
85 /// The int tensor with the same data as the bool tensor.
86 fn bool_into_int(tensor: BoolTensor<B>) -> IntTensor<B>;
87
88 /// Converts bool tensor to float tensor.
89 ///
90 /// # Arguments
91 ///
92 /// * `tensor` - The tensor.
93 ///
94 /// # Returns
95 ///
96 /// The float tensor with the same data as the bool tensor.
97 fn bool_into_float(tensor: BoolTensor<B>) -> FloatTensor<B>;
98
99 /// Gets the device of the tensor.
100 ///
101 /// # Arguments
102 ///
103 /// * `tensor` - The tensor.
104 ///
105 /// # Returns
106 ///
107 /// The device of the tensor.
108 fn bool_device(tensor: &BoolTensor<B>) -> Device<B>;
109
110 /// Moves the tensor to the device.
111 fn bool_to_device(tensor: BoolTensor<B>, device: &Device<B>) -> BoolTensor<B>;
112
113 /// Reshapes the tensor.
114 ///
115 /// # Arguments
116 ///
117 /// * `tensor` - The tensor.
118 /// * `shape` - The new shape.
119 ///
120 /// # Returns
121 ///
122 /// The tensor with the new shape.
123 fn bool_reshape(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B>;
124
125 /// Gets the values from the tensor for the given ranges.
126 ///
127 /// # Arguments
128 ///
129 /// * `tensor` - The tensor.
130 /// * `slices` - The slices specifying ranges and steps for each dimension.
131 ///
132 /// # Returns
133 ///
134 /// The tensor with the values for the given slices.
135 ///
136 /// # Note
137 ///
138 /// Empty slices (where start >= end) are handled at the high-level tensor API and will not
139 /// be passed to this method. Backend implementations do not need to handle empty slices.
140 fn bool_slice(tensor: BoolTensor<B>, slices: &[Slice]) -> BoolTensor<B>;
141
142 /// Sets the values in the tensor for the given ranges.
143 ///
144 /// # Arguments
145 ///
146 /// * `tensor` - The tensor.
147 /// * `ranges` - The ranges to set the values for.
148 /// * `value` - The values to set.
149 ///
150 /// # Returns
151 ///
152 /// The tensor with the values set for the given ranges.
153 ///
154 /// # Note
155 ///
156 /// Empty slice assignments (where any slice range produces 0 elements) are handled at the
157 /// high-level tensor API and will not be passed to this method. Backend implementations do
158 /// not need to handle empty slice assignments.
159 fn bool_slice_assign(
160 tensor: BoolTensor<B>,
161 slices: &[Slice],
162 value: BoolTensor<B>,
163 ) -> BoolTensor<B>;
164
165 /// Select tensor elements along the given dimension corresponding to the given indices.
166 ///
167 /// # Arguments
168 ///
169 /// * `tensor` - The tensor to select from.
170 /// * `dim` - The dimension to select from.
171 /// * `indices` - The indices of the elements to select.
172 ///
173 /// # Returns
174 ///
175 /// The tensor with the selected elements.
176 fn bool_select(tensor: BoolTensor<B>, dim: usize, indices: IntTensor<B>) -> BoolTensor<B> {
177 // Default implementation: convert to int, select, then convert back to bool
178 let int_tensor = B::bool_into_int(tensor);
179 let selected = B::int_select(int_tensor, dim, indices);
180 B::int_equal_elem(selected, 1_i32.elem())
181 }
182
183 /// Assign the selected elements along the given dimension corresponding to the given indices
184 /// to the given value using sum reduction.
185 ///
186 /// # Arguments
187 ///
188 /// * `tensor` - The tensor to assign the values to.
189 /// * `dim` - The dimension to select from.
190 /// * `indices` - The indices of the elements to assign.
191 /// * `value` - The values to assign.
192 ///
193 /// # Returns
194 ///
195 /// The tensor with the assigned values.
196 fn bool_select_add(
197 tensor: BoolTensor<B>,
198 dim: usize,
199 indices: IntTensor<B>,
200 value: BoolTensor<B>,
201 ) -> BoolTensor<B> {
202 // Default implementation: convert to int, select_assign, then convert back to bool
203 let int_tensor = B::bool_into_int(tensor);
204 let int_values = B::bool_into_int(value);
205 let assigned = B::int_select_add(int_tensor, dim, indices, int_values);
206 // After select_assign with sum reduction, any non-zero value should be true
207 B::int_greater_elem(assigned, 0_i32.elem())
208 }
209
210 /// Repeats one dimension of the tensor a given number of times along that dimension.
211 ///
212 /// # Arguments
213 ///
214 /// * `tensor` - The tensor.
215 /// * `dim` - The dimension to repeat.
216 /// * `times` - The number of times to repeat the dimension.
217 ///
218 /// # Returns
219 ///
220 /// The tensor with the dimension repeated.
221 fn bool_repeat_dim(tensor: BoolTensor<B>, dim: usize, times: usize) -> BoolTensor<B> {
222 repeat_with_slice_assign::<B, Bool>(tensor, dim, times)
223 }
224
225 /// Concatenates the tensors along the given dimension.
226 ///
227 /// # Arguments
228 ///
229 /// * `tensors` - The tensors to concatenate.
230 /// * `dim` - The dimension to concatenate along.
231 ///
232 /// # Returns
233 ///
234 /// The tensor with the tensors concatenated along the given dimension.
235 ///
236 /// # Note
237 ///
238 /// Empty tensors (where the concatenation dimension has size 0) are filtered out at the
239 /// high-level tensor API and will not be passed to this method. Backend implementations do
240 /// not need to handle empty tensors.
241 fn bool_cat(tensors: Vec<BoolTensor<B>>, dim: usize) -> BoolTensor<B> {
242 cat_with_slice_assign::<B, Bool>(tensors, dim)
243 }
244
245 /// Equates the two tensors.
246 ///
247 /// # Arguments
248 ///
249 /// * `lhs` - The left hand side tensor.
250 /// * `rhs` - The right hand side tensor.
251 ///
252 /// # Returns
253 ///
254 /// The tensor with the result of the equate.
255 fn bool_equal(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;
256
257 /// Element-wise non-equality comparison.
258 ///
259 /// # Arguments
260 ///
261 /// * `lhs` - The left hand side tensor.
262 /// * `rhs` - The right hand side tensor.
263 ///
264 /// # Returns
265 ///
266 /// The tensor with the result of the comparison.
267 fn bool_not_equal(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
268 let equal_tensor = B::bool_equal(lhs, rhs);
269 B::bool_not(equal_tensor)
270 }
271
272 /// Inverses boolean values.
273 ///
274 /// # Arguments
275 ///
276 /// * `tensor` - The tensor.
277 ///
278 /// # Returns
279 ///
280 /// The tensor with the result of the negation.
281 fn bool_not(tensor: BoolTensor<B>) -> BoolTensor<B>;
282
283 /// Executes the logical and (`&&`) operation on two boolean tensors.
284 ///
285 /// # Arguments
286 ///
287 /// * `lhs` - The left hand side tensor.
288 /// * `rhs` - The right hand side tensor.
289 ///
290 /// # Returns
291 ///
292 /// The tensor with the result of the logical and.
293 fn bool_and(tensor: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;
294
295 /// Executes the logical or (`||`) operation on two boolean tensors.
296 ///
297 /// # Arguments
298 ///
299 /// * `lhs` - The left hand side tensor.
300 /// * `rhs` - The right hand side tensor.
301 ///
302 /// # Returns
303 ///
304 /// The tensor with the result of the logical or.
305 fn bool_or(tensor: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;
306
307 /// Element-wise exclusive or.
308 ///
309 /// # Arguments
310 ///
311 /// * `lhs` - The left hand side tensor.
312 /// * `rhs` - The right hand side tensor.
313 ///
314 /// # Returns
315 ///
316 /// The tensor with the result of the comparison.
317 fn bool_xor(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
318 Self::bool_not_equal(lhs, rhs)
319 }
320
321 /// Transposes a bool tensor.
322 ///
323 /// # Arguments
324 ///
325 /// * `tensor` - The tensor to transpose.
326 ///
327 /// # Returns
328 ///
329 /// The transposed tensor.
330 fn bool_transpose(tensor: BoolTensor<B>) -> BoolTensor<B> {
331 let ndims = tensor.shape().num_dims();
332 Self::bool_swap_dims(tensor, ndims - 2, ndims - 1)
333 }
334
335 /// Swaps two dimensions of a bool tensor.
336 ///
337 /// # Arguments
338 ///
339 /// * `tensor` - The tensor to swap the dimensions of.
340 /// * `dim1` - The first dimension to swap.
341 /// * `dim2` - The second dimension to swap.
342 ///
343 /// # Returns
344 ///
345 /// The tensor with the dimensions swapped.
346 fn bool_swap_dims(tensor: BoolTensor<B>, dim1: usize, dim2: usize) -> BoolTensor<B>;
347
348 /// Permutes the dimensions of a tensor.
349 ///
350 /// # Arguments
351 ///
352 /// * `tensor` - The tensor to permute the dimensions of.
353 /// * `axes` - The new order of the dimensions.
354 /// # Returns
355 ///
356 /// The tensor with the dimensions permuted.
357 fn bool_permute(tensor: BoolTensor<B>, axes: &[usize]) -> BoolTensor<B>;
358
359 /// Reverse the order of elements in a tensor along the given axes.
360 ///
361 /// # Arguments
362 ///
363 /// * `tensor` - The tensor to reverse.
364 /// * `axes` - The axes to reverse.
365 ///
366 /// The tensor with the elements reversed.
367 fn bool_flip(tensor: BoolTensor<B>, axes: &[usize]) -> BoolTensor<B>;
368
369 /// Tests if any element in the boolean `tensor` evaluates to True.
370 ///
371 /// # Arguments
372 ///
373 /// * `tensor` - The tensor to test.
374 ///
375 /// # Returns
376 ///
377 /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
378 fn bool_any(tensor: BoolTensor<B>) -> BoolTensor<B> {
379 let sum = B::int_sum(B::bool_into_int(tensor));
380 B::int_greater_elem(sum, 0.elem())
381 }
382
383 /// Tests if any element in the boolean `tensor` evaluates to True along a given dimension `dim`.
384 ///
385 /// # Arguments
386 ///
387 /// * `tensor` - The tensor to test.
388 /// * `dim` - The axis along which to test.
389 ///
390 /// # Returns
391 ///
392 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
393 /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input
394 /// evaluates to True, False otherwise.
395 fn bool_any_dim(tensor: BoolTensor<B>, dim: usize) -> BoolTensor<B> {
396 let sum = B::int_sum_dim(B::bool_into_int(tensor), dim);
397 B::int_greater_elem(sum, 0.elem())
398 }
399
400 /// Tests if all elements in the boolean `tensor` evaluate to True.
401 ///
402 /// # Arguments
403 ///
404 /// * `tensor` - The tensor to test.
405 ///
406 /// # Returns
407 ///
408 /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
409 /// evaluate to True, False otherwise.
410 fn bool_all(tensor: BoolTensor<B>) -> BoolTensor<B> {
411 let num_elems = tensor.shape().num_elements();
412 let sum = B::int_sum(B::bool_into_int(tensor));
413 B::int_equal_elem(sum, (num_elems as i32).elem())
414 }
415
416 /// Tests if all elements in the boolean `tensor` evaluate to True along a given dimension `dim`.
417 ///
418 /// # Arguments
419 ///
420 /// * `tensor` - The tensor to test.
421 /// * `dim` - The axis along which to test.
422 ///
423 /// # Returns
424 ///
425 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
426 /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
427 /// evaluates to True, False otherwise.
428 fn bool_all_dim(tensor: BoolTensor<B>, dim: usize) -> BoolTensor<B> {
429 let num_elems = tensor.shape().dims[dim];
430 let sum = B::int_sum_dim(B::bool_into_int(tensor), dim);
431 B::int_equal_elem(sum, (num_elems as i32).elem())
432 }
433
434 /// Compute the indices of the elements that are non-zero, grouped by element.
435 ///
436 /// # Arguments
437 ///
438 /// * `tensor` - The input tensor.
439 ///
440 /// # Returns
441 ///
442 /// A 2D tensor containing the indices of all non-zero elements of the given tensor.
443 /// Each row contains the indices of a non-zero element.
444 fn bool_argwhere(tensor: BoolTensor<B>) -> impl Future<Output = IntTensor<B>> + 'static + Send {
445 async {
446 // Size of each output tensor is variable (= number of nonzero elements in the tensor).
447 // Reading the data to count the number of truth values might cause sync but is required.
448 let device = B::bool_device(&tensor);
449 let data = B::bool_into_data(tensor)
450 .await
451 .expect("Can read the data without error");
452 argwhere_data::<B>(data, &device)
453 }
454 }
455
456 /// Broadcasts the bool `tensor` to the given `shape`.
457 fn bool_expand(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B>;
458
459 /// Unfold windows along a dimension.
460 ///
461 /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;
462 /// where windows are advanced by `step` at each index.
463 ///
464 /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.
465 ///
466 /// # Arguments
467 ///
468 /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
469 /// * `dim` - the selected dim.
470 /// * `size` - the size of each unfolded window.
471 /// * `step` - the step between each window.
472 ///
473 /// # Returns
474 ///
475 /// A tensor view with shape ``[pre=..., windows, size, post=...]``.
476 fn bool_unfold(tensor: BoolTensor<B>, dim: usize, size: usize, step: usize) -> BoolTensor<B>;
477}