burn_tensor/tensor/ops/bool_tensor.rs
1use super::{
2 BoolTensor, Device, FloatTensor, IntTensor, cat::cat_with_slice_assign,
3 repeat_dim::repeat_with_slice_assign,
4};
5use crate::{
6 Bool, ElementConversion, TensorData, TensorMetadata, argwhere_data, backend::Backend,
7 tensor::Shape,
8};
9use alloc::vec::Vec;
10use core::future::Future;
11
12/// Bool Tensor API for basic operations, see [tensor](crate::Tensor)
13/// for documentation on each function.
14pub trait BoolTensorOps<B: Backend> {
15 /// Creates a new bool tensor.
16 ///
17 /// # Arguments
18 ///
19 /// * `shape` - The shape of the tensor.
20 /// * `device` - The device to create the tensor on.
21 ///
22 /// # Returns
23 ///
24 /// The boolean tensor with the given shape.
25 fn bool_empty(shape: Shape, device: &Device<B>) -> BoolTensor<B>;
26
27 /// Creates a new bool tensor filled false.
28 ///
29 /// # Arguments
30 ///
31 /// * `shape` - The shape of the tensor.
32 /// * `device` - The device to create the tensor on.
33 ///
34 /// # Returns
35 ///
36 /// The boolean tensor filled with false.
37 fn bool_zeros(shape: Shape, device: &Device<B>) -> BoolTensor<B>;
38
39 /// Creates a new bool tensor filled true.
40 ///
41 /// # Arguments
42 ///
43 /// * `shape` - The shape of the tensor.
44 /// * `device` - The device to create the tensor on.
45 ///
46 /// # Returns
47 ///
48 /// The boolean tensor filled with true.
49 fn bool_ones(shape: Shape, device: &Device<B>) -> BoolTensor<B>;
50
51 /// Converts the tensor to a data structure.
52 ///
53 /// # Arguments
54 ///
55 /// * `tensor` - The tensor.
56 ///
57 /// # Returns
58 ///
59 /// The data structure with the tensor's data.
60 fn bool_into_data(tensor: BoolTensor<B>) -> impl Future<Output = TensorData> + Send;
61
62 /// Creates a tensor from the data structure.
63 ///
64 /// # Arguments
65 ///
66 /// * `data` - The data structure.
67 /// * `device` - The device to create the tensor on.
68 ///
69 /// # Returns
70 ///
71 /// The tensor with the data.
72 fn bool_from_data(data: TensorData, device: &Device<B>) -> BoolTensor<B>;
73
74 /// Converts bool tensor to int tensor.
75 ///
76 /// # Arguments
77 ///
78 /// * `tensor` - The tensor.
79 ///
80 /// # Returns
81 ///
82 /// The int tensor with the same data as the bool tensor.
83 fn bool_into_int(tensor: BoolTensor<B>) -> IntTensor<B>;
84
85 /// Converts bool tensor to float tensor.
86 ///
87 /// # Arguments
88 ///
89 /// * `tensor` - The tensor.
90 ///
91 /// # Returns
92 ///
93 /// The float tensor with the same data as the bool tensor.
94 fn bool_into_float(tensor: BoolTensor<B>) -> FloatTensor<B>;
95
96 /// Gets the device of the tensor.
97 ///
98 /// # Arguments
99 ///
100 /// * `tensor` - The tensor.
101 ///
102 /// # Returns
103 ///
104 /// The device of the tensor.
105 fn bool_device(tensor: &BoolTensor<B>) -> Device<B>;
106
107 /// Moves the tensor to the device.
108 fn bool_to_device(tensor: BoolTensor<B>, device: &Device<B>) -> BoolTensor<B>;
109
110 /// Reshapes the tensor.
111 ///
112 /// # Arguments
113 ///
114 /// * `tensor` - The tensor.
115 /// * `shape` - The new shape.
116 ///
117 /// # Returns
118 ///
119 /// The tensor with the new shape.
120 fn bool_reshape(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B>;
121
122 /// Gets the values from the tensor for the given ranges.
123 ///
124 /// # Arguments
125 ///
126 /// * `tensor` - The tensor.
127 /// * `slices` - The slices specifying ranges and steps for each dimension.
128 ///
129 /// # Returns
130 ///
131 /// The tensor with the values for the given slices.
132 fn bool_slice(tensor: BoolTensor<B>, slices: &[crate::Slice]) -> BoolTensor<B>;
133
134 /// Sets the values in the tensor for the given ranges.
135 ///
136 /// # Arguments
137 ///
138 /// * `tensor` - The tensor.
139 /// * `ranges` - The ranges to set the values for.
140 /// * `value` - The values to set.
141 ///
142 /// # Returns
143 ///
144 /// The tensor with the values set for the given ranges.
145 fn bool_slice_assign(
146 tensor: BoolTensor<B>,
147 slices: &[crate::Slice],
148 value: BoolTensor<B>,
149 ) -> BoolTensor<B>;
150
151 /// Select tensor elements along the given dimension corresponding to the given indices.
152 ///
153 /// # Arguments
154 ///
155 /// * `tensor` - The tensor to select from.
156 /// * `dim` - The dimension to select from.
157 /// * `indices` - The indices of the elements to select.
158 ///
159 /// # Returns
160 ///
161 /// The tensor with the selected elements.
162 fn bool_select(tensor: BoolTensor<B>, dim: usize, indices: IntTensor<B>) -> BoolTensor<B> {
163 // Default implementation: convert to int, select, then convert back to bool
164 let int_tensor = B::bool_into_int(tensor);
165 let selected = B::int_select(int_tensor, dim, indices);
166 B::int_equal_elem(selected, 1_i32.elem())
167 }
168
169 /// Assign the selected elements along the given dimension corresponding to the given indices
170 /// to the given value.
171 ///
172 /// # Arguments
173 ///
174 /// * `tensor` - The tensor to assign the values to.
175 /// * `dim` - The dimension to select from.
176 /// * `indices` - The indices of the elements to assign.
177 /// * `value` - The values to assign.
178 ///
179 /// # Returns
180 ///
181 /// The tensor with the assigned values.
182 fn bool_select_assign(
183 tensor: BoolTensor<B>,
184 dim: usize,
185 indices: IntTensor<B>,
186 value: BoolTensor<B>,
187 ) -> BoolTensor<B> {
188 // Default implementation: convert to int, select_assign, then convert back to bool
189 let int_tensor = B::bool_into_int(tensor);
190 let int_values = B::bool_into_int(value);
191 let assigned = B::int_select_assign(int_tensor, dim, indices, int_values);
192 // After select_assign with sum reduction, any non-zero value should be true
193 B::int_greater_elem(assigned, 0_i32.elem())
194 }
195
196 /// Repeats one dimension of the tensor a given number of times along that dimension.
197 ///
198 /// # Arguments
199 ///
200 /// * `tensor` - The tensor.
201 /// * `dim` - The dimension to repeat.
202 /// * `times` - The number of times to repeat the dimension.
203 ///
204 /// # Returns
205 ///
206 /// The tensor with the dimension repeated.
207 fn bool_repeat_dim(tensor: BoolTensor<B>, dim: usize, times: usize) -> BoolTensor<B> {
208 repeat_with_slice_assign::<B, Bool>(tensor, dim, times)
209 }
210
211 /// Concatenates the tensors along the given dimension.
212 ///
213 /// # Arguments
214 ///
215 /// * `tensors` - The tensors to concatenate.
216 /// * `dim` - The dimension to concatenate along.
217 ///
218 /// # Returns
219 ///
220 /// The tensor with the tensors concatenated along the given dimension.
221 fn bool_cat(tensors: Vec<BoolTensor<B>>, dim: usize) -> BoolTensor<B> {
222 cat_with_slice_assign::<B, Bool>(tensors, dim)
223 }
224
225 /// Equates the two tensors.
226 ///
227 /// # Arguments
228 ///
229 /// * `lhs` - The left hand side tensor.
230 /// * `rhs` - The right hand side tensor.
231 ///
232 /// # Returns
233 ///
234 /// The tensor with the result of the equate.
235 fn bool_equal(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;
236
237 /// Element-wise non-equality comparison.
238 ///
239 /// # Arguments
240 ///
241 /// * `lhs` - The left hand side tensor.
242 /// * `rhs` - The right hand side tensor.
243 ///
244 /// # Returns
245 ///
246 /// The tensor with the result of the comparison.
247 fn bool_not_equal(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
248 let equal_tensor = B::bool_equal(lhs, rhs);
249 B::bool_not(equal_tensor)
250 }
251
252 /// Inverses boolean values.
253 ///
254 /// # Arguments
255 ///
256 /// * `tensor` - The tensor.
257 ///
258 /// # Returns
259 ///
260 /// The tensor with the result of the negation.
261 fn bool_not(tensor: BoolTensor<B>) -> BoolTensor<B>;
262
263 /// Executes the logical and (`&&`) operation on two boolean tensors.
264 ///
265 /// # Arguments
266 ///
267 /// * `lhs` - The left hand side tensor.
268 /// * `rhs` - The right hand side tensor.
269 ///
270 /// # Returns
271 ///
272 /// The tensor with the result of the logical and.
273 fn bool_and(tensor: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;
274
275 /// Executes the logical or (`||`) operation on two boolean tensors.
276 ///
277 /// # Arguments
278 ///
279 /// * `lhs` - The left hand side tensor.
280 /// * `rhs` - The right hand side tensor.
281 ///
282 /// # Returns
283 ///
284 /// The tensor with the result of the logical or.
285 fn bool_or(tensor: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;
286
287 /// Element-wise exclusive or.
288 ///
289 /// # Arguments
290 ///
291 /// * `lhs` - The left hand side tensor.
292 /// * `rhs` - The right hand side tensor.
293 ///
294 /// # Returns
295 ///
296 /// The tensor with the result of the comparison.
297 fn bool_xor(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
298 Self::bool_not_equal(lhs, rhs)
299 }
300
301 /// Transposes a bool tensor.
302 ///
303 /// # Arguments
304 ///
305 /// * `tensor` - The tensor to transpose.
306 ///
307 /// # Returns
308 ///
309 /// The transposed tensor.
310 fn bool_transpose(tensor: BoolTensor<B>) -> BoolTensor<B> {
311 let ndims = tensor.shape().num_dims();
312 Self::bool_swap_dims(tensor, ndims - 2, ndims - 1)
313 }
314
315 /// Swaps two dimensions of a bool tensor.
316 ///
317 /// # Arguments
318 ///
319 /// * `tensor` - The tensor to swap the dimensions of.
320 /// * `dim1` - The first dimension to swap.
321 /// * `dim2` - The second dimension to swap.
322 ///
323 /// # Returns
324 ///
325 /// The tensor with the dimensions swapped.
326 fn bool_swap_dims(tensor: BoolTensor<B>, dim1: usize, dim2: usize) -> BoolTensor<B>;
327
328 /// Permutes the dimensions of a tensor.
329 ///
330 /// # Arguments
331 ///
332 /// * `tensor` - The tensor to permute the dimensions of.
333 /// * `axes` - The new order of the dimensions.
334 /// # Returns
335 ///
336 /// The tensor with the dimensions permuted.
337 fn bool_permute(tensor: BoolTensor<B>, axes: &[usize]) -> BoolTensor<B>;
338
339 /// Reverse the order of elements in a tensor along the given axes.
340 ///
341 /// # Arguments
342 ///
343 /// * `tensor` - The tensor to reverse.
344 /// * `axes` - The axes to reverse.
345 ///
346 /// The tensor with the elements reversed.
347 fn bool_flip(tensor: BoolTensor<B>, axes: &[usize]) -> BoolTensor<B>;
348
349 /// Tests if any element in the boolean `tensor` evaluates to True.
350 ///
351 /// # Arguments
352 ///
353 /// * `tensor` - The tensor to test.
354 ///
355 /// # Returns
356 ///
357 /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
358 fn bool_any(tensor: BoolTensor<B>) -> BoolTensor<B> {
359 let sum = B::int_sum(B::bool_into_int(tensor));
360 B::int_greater_elem(sum, 0.elem())
361 }
362
363 /// Tests if any element in the boolean `tensor` evaluates to True along a given dimension `dim`.
364 ///
365 /// # Arguments
366 ///
367 /// * `tensor` - The tensor to test.
368 /// * `dim` - The axis along which to test.
369 ///
370 /// # Returns
371 ///
372 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
373 /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input
374 /// evaluates to True, False otherwise.
375 fn bool_any_dim(tensor: BoolTensor<B>, dim: usize) -> BoolTensor<B> {
376 let sum = B::int_sum_dim(B::bool_into_int(tensor), dim);
377 B::int_greater_elem(sum, 0.elem())
378 }
379
380 /// Tests if all elements in the boolean `tensor` evaluate to True.
381 ///
382 /// # Arguments
383 ///
384 /// * `tensor` - The tensor to test.
385 ///
386 /// # Returns
387 ///
388 /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
389 /// evaluate to True, False otherwise.
390 fn bool_all(tensor: BoolTensor<B>) -> BoolTensor<B> {
391 let num_elems = tensor.shape().num_elements();
392 let sum = B::int_sum(B::bool_into_int(tensor));
393 B::int_equal_elem(sum, (num_elems as i32).elem())
394 }
395
396 /// Tests if all elements in the boolean `tensor` evaluate to True along a given dimension `dim`.
397 ///
398 /// # Arguments
399 ///
400 /// * `tensor` - The tensor to test.
401 /// * `dim` - The axis along which to test.
402 ///
403 /// # Returns
404 ///
405 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
406 /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
407 /// evaluates to True, False otherwise.
408 fn bool_all_dim(tensor: BoolTensor<B>, dim: usize) -> BoolTensor<B> {
409 let num_elems = tensor.shape().dims[dim];
410 let sum = B::int_sum_dim(B::bool_into_int(tensor), dim);
411 B::int_equal_elem(sum, (num_elems as i32).elem())
412 }
413
414 /// Compute the indices of the elements that are non-zero, grouped by element.
415 ///
416 /// # Arguments
417 ///
418 /// * `tensor` - The input tensor.
419 ///
420 /// # Returns
421 ///
422 /// A 2D tensor containing the indices of all non-zero elements of the given tensor.
423 /// Each row contains the indices of a non-zero element.
424 fn bool_argwhere(tensor: BoolTensor<B>) -> impl Future<Output = IntTensor<B>> + 'static + Send {
425 async {
426 // Size of each output tensor is variable (= number of nonzero elements in the tensor).
427 // Reading the data to count the number of truth values might cause sync but is required.
428 let device = B::bool_device(&tensor);
429 let data = B::bool_into_data(tensor).await;
430 argwhere_data::<B>(data, &device)
431 }
432 }
433
434 /// Broadcasts the bool `tensor` to the given `shape`.
435 fn bool_expand(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B>;
436
437 /// Unfold windows along a dimension.
438 ///
439 /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;
440 /// where windows are advanced by `step` at each index.
441 ///
442 /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.
443 ///
444 /// # Arguments
445 ///
446 /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
447 /// * `dim` - the selected dim.
448 /// * `size` - the size of each unfolded window.
449 /// * `step` - the step between each window.
450 ///
451 /// # Returns
452 ///
453 /// A tensor view with shape ``[pre=..., windows, size, post=...]``.
454 fn bool_unfold(tensor: BoolTensor<B>, dim: usize, size: usize, step: usize) -> BoolTensor<B>;
455}