burn_tensor/tensor/ops/int_tensor.rs
1use super::cat::cat_with_slice_assign;
2use super::repeat_dim::repeat_with_slice_assign;
3use super::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
4use crate::cast::ToElement;
5use crate::tensor::api::{chunk, narrow, split, split_with_sizes};
6use crate::{Distribution, ElementConversion, Int, TensorData, backend::Backend, tensor::Shape};
7use alloc::vec::Vec;
8use core::future::Future;
9use core::ops::Range;
10
11use crate::{TensorMetadata, argsort, sort, sort_with_indices};
12
13/// Int Tensor API for basic and numeric operations, see [tensor](crate::Tensor)
14/// for documentation on each function.
15pub trait IntTensorOps<B: Backend> {
16 /// Creates a new int tensor of the given shape on the given device.
17 ///
18 /// # Arguments
19 ///
20 /// * `shape` - The shape of the tensor to create.
21 /// * `device` - The device to create the tensor on.
22 ///
23 /// # Returns
24 ///
25 /// The integer tensor with the given shape. NOTE(review): contents are presumably unspecified (see `int_zeros` for a zero-filled tensor) — confirm per backend.
26 fn int_empty(shape: Shape, device: &Device<B>) -> IntTensor<B>;
27
28 /// Converts the tensor into a [`TensorData`] structure.
29 ///
30 /// # Arguments
31 ///
32 /// * `tensor` - The tensor to read; it is consumed by the call.
33 ///
34 /// # Returns
35 ///
36 /// A `Send + 'static` future that resolves to the data structure with the tensor's data.
37 fn int_into_data(tensor: IntTensor<B>) -> impl Future<Output = TensorData> + 'static + Send;
38
39 /// Creates an int tensor from a [`TensorData`] structure.
40 ///
41 /// # Arguments
42 ///
43 /// * `data` - The data structure to build the tensor from.
44 /// * `device` - The device to create the tensor on.
45 ///
46 /// # Returns
47 ///
48 /// The int tensor holding the given data.
49 fn int_from_data(data: TensorData, device: &Device<B>) -> IntTensor<B>;
50
51 /// Returns the device of the tensor.
52 ///
53 /// # Arguments
54 ///
55 /// * `tensor` - The tensor to query; it is only borrowed.
56 ///
57 /// # Returns
58 ///
59 /// The device of the tensor.
60 fn int_device(tensor: &IntTensor<B>) -> Device<B>;
61
62 /// Moves the tensor to the given `device` and returns the resulting tensor.
63 fn int_to_device(tensor: IntTensor<B>, device: &Device<B>) -> IntTensor<B>;
64
65 /// Reshapes the tensor to the given shape.
66 ///
67 /// # Arguments
68 ///
69 /// * `tensor` - The tensor to reshape.
70 /// * `shape` - The new shape.
71 ///
72 /// # Returns
73 ///
74 /// The tensor with the new shape.
75 fn int_reshape(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B>;
76
77 /// Gets the elements within the given index ranges.
78 ///
79 /// # Arguments
80 ///
81 /// * `tensor` - The tensor to slice.
82 /// * `indices` - The index ranges to select. NOTE(review): presumably one range per dimension, in order — confirm.
83 ///
84 /// # Returns
85 ///
86 /// The elements at the given indices.
87 fn int_slice(tensor: IntTensor<B>, indices: &[Range<usize>]) -> IntTensor<B>;
88
89 /// Sets the elements within the given index ranges to the given values.
90 ///
91 /// # Arguments
92 ///
93 /// * `tensor` - The tensor to assign into.
94 /// * `indices` - The index ranges to overwrite.
/// * `value` - The tensor holding the values to write into the selected ranges.
95 ///
96 /// # Returns
97 ///
98 /// The tensor with the elements at the given indices set.
99 fn int_slice_assign(
100 tensor: IntTensor<B>,
101 indices: &[Range<usize>],
102 value: IntTensor<B>,
103 ) -> IntTensor<B>;
104
105 /// Converts an int tensor into a float tensor.
106 ///
107 /// # Arguments
108 ///
109 /// * `tensor` - The int tensor to convert.
110 ///
111 /// # Returns
112 ///
113 /// The float tensor with the same data as the int tensor.
114 fn int_into_float(tensor: IntTensor<B>) -> FloatTensor<B>;
115
116 /// Fills the tensor with values from the source tensor at the positions where the mask
117 /// is true.
118 ///
119 /// # Arguments
120 ///
121 /// * `tensor` - The tensor to fill.
122 /// * `mask` - The boolean mask selecting the positions to fill.
123 /// * `source` - The source tensor providing the replacement values.
124 ///
125 /// # Returns
126 ///
127 /// The tensor with the values filled.
128 fn int_mask_where(
129 tensor: IntTensor<B>,
130 mask: BoolTensor<B>,
131 source: IntTensor<B>,
132 ) -> IntTensor<B>;
133
134 /// Fills the tensor with the given scalar value at the positions where the mask is true.
135 ///
136 /// # Arguments
137 ///
138 /// * `tensor` - The tensor to fill.
139 /// * `mask` - The boolean mask selecting the positions to fill.
140 /// * `value` - The scalar value to write.
141 ///
142 /// # Returns
143 ///
144 /// The tensor with the values filled.
145 fn int_mask_fill(tensor: IntTensor<B>, mask: BoolTensor<B>, value: IntElem<B>) -> IntTensor<B>;
146
147 /// Gathers elements from the tensor along `dim` at the given indices and returns them.
148 ///
149 /// # Arguments
150 ///
151 /// * `dim` - The dimension to gather from.
152 /// * `tensor` - The tensor to gather from.
153 /// * `indices` - The int tensor of indices to gather.
154 fn int_gather(dim: usize, tensor: IntTensor<B>, indices: IntTensor<B>) -> IntTensor<B>;
155
156 /// Scatters the given values into the tensor at the given indices.
157 ///
158 /// # Arguments
159 ///
160 /// * `dim` - The dimension to scatter to.
161 /// * `tensor` - The tensor to scatter into.
162 /// * `indices` - The int tensor of indices.
163 /// * `value` - The tensor of values to scatter.
164 ///
165 /// # Returns
166 ///
167 /// The tensor with the values scattered.
168 fn int_scatter(
169 dim: usize,
170 tensor: IntTensor<B>,
171 indices: IntTensor<B>,
172 value: IntTensor<B>,
173 ) -> IntTensor<B>;
174
175 /// Selects tensor elements along the given dimension corresponding to the given indices.
176 ///
177 /// # Arguments
178 ///
179 /// * `tensor` - The tensor to select from.
180 /// * `dim` - The dimension to select from.
181 /// * `indices` - The int tensor of indices to select.
182 ///
183 /// # Returns
184 ///
185 /// The tensor with the selected elements.
186 fn int_select(tensor: IntTensor<B>, dim: usize, indices: IntTensor<B>) -> IntTensor<B>;
187
188 /// Assigns the given values to the elements selected along the given dimension by the
189 /// given indices.
190 ///
191 /// # Arguments
192 ///
193 /// * `tensor` - The tensor to assign into.
194 /// * `dim` - The dimension to select from.
195 /// * `indices` - The int tensor of indices to select.
196 /// * `value` - The tensor of values to assign.
197 ///
198 /// # Returns
199 ///
200 /// The tensor with the selected elements assigned to the given value.
201 fn int_select_assign(
202 tensor: IntTensor<B>,
203 dim: usize,
204 indices: IntTensor<B>,
205 value: IntTensor<B>,
206 ) -> IntTensor<B>;
207
208 /// Repeats the tensor along the given dimension the given number of times.
209 ///
210 /// # Arguments
211 ///
212 /// * `tensor` - The tensor.
213 /// * `dim` - The dimension to repeat.
214 /// * `times` - The number of times to repeat.
215 ///
216 /// # Returns
217 ///
218 /// The tensor with the given dimension repeated the given number of times.
219 fn int_repeat_dim(tensor: IntTensor<B>, dim: usize, times: usize) -> IntTensor<B> {
220 repeat_with_slice_assign::<B, Int>(tensor, dim, times)
221 }
222
223 /// Concatenates the given tensors along the given dimension.
224 ///
225 /// # Arguments
226 ///
227 /// * `tensors` - The tensors.
228 /// * `dim` - The dimension to concatenate along.
229 ///
230 /// # Returns
231 ///
232 /// The concatenated tensor.
233 fn int_cat(tensors: Vec<IntTensor<B>>, dim: usize) -> IntTensor<B> {
234 cat_with_slice_assign::<B, Int>(tensors, dim)
235 }
236
237 /// Element-wise equality (`lhs == rhs`) comparison between two tensors.
238 ///
239 /// # Arguments
240 ///
241 /// * `lhs` - The left hand side tensor.
242 /// * `rhs` - The right hand side tensor.
243 ///
244 /// # Returns
245 ///
246 /// The boolean tensor with the result of the comparison.
247 fn int_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;
248
249 /// Element-wise non-equality comparison.
250 ///
251 /// # Arguments
252 ///
253 /// * `lhs` - The left hand side tensor.
254 /// * `rhs` - The right hand side tensor.
255 ///
256 /// # Returns
257 ///
258 /// The boolean tensor with the result of the comparison.
259 fn int_not_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {
260 let equal_tensor = B::int_equal(lhs, rhs);
261 B::bool_not(equal_tensor)
262 }
263
264 /// Element-wise equality (`lhs == rhs`) comparison between a tensor and a scalar.
265 ///
266 /// # Arguments
267 ///
268 /// * `lhs` - The left hand side tensor.
269 /// * `rhs` - The right hand side scalar.
270 ///
271 /// # Returns
272 ///
273 /// The boolean tensor with the result of the comparison.
274 fn int_equal_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;
275
276 /// Element-wise non-equality comparison with a scalar.
277 ///
278 /// # Arguments
279 ///
280 /// * `lhs` - The left hand side tensor.
281 /// * `rhs` - The right hand side scalar.
282 ///
283 /// # Returns
284 ///
285 /// The boolean tensor with the result of the comparison.
286 fn int_not_equal_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B> {
287 let equal_tensor = B::int_equal_elem(lhs, rhs);
288 B::bool_not(equal_tensor)
289 }
290
291 /// Element-wise greater than (`lhs > rhs`) comparison between two tensors.
292 ///
293 /// # Arguments
294 ///
295 /// * `lhs` - The left hand side tensor.
296 /// * `rhs` - The right hand side tensor.
297 ///
298 /// # Returns
299 ///
300 /// The boolean tensor with the result of the comparison.
301 fn int_greater(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;
302
303 /// Element-wise greater than (`lhs > rhs`) comparison between a tensor and a scalar.
304 ///
305 /// # Arguments
306 ///
307 /// * `lhs` - The left hand side tensor.
308 /// * `rhs` - The right hand side scalar.
309 ///
310 /// # Returns
311 ///
312 /// The boolean tensor with the result of the comparison.
313 fn int_greater_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;
314
315 /// Element-wise greater than or equal (`lhs >= rhs`) comparison between two tensors.
316 ///
317 /// # Arguments
318 ///
319 /// * `lhs` - The left hand side tensor.
320 /// * `rhs` - The right hand side tensor.
321 ///
322 /// # Returns
323 ///
324 /// The boolean tensor with the result of the comparison.
325 fn int_greater_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;
326
327 /// Element-wise greater than or equal (`lhs >= rhs`) comparison between a tensor and a scalar.
328 ///
329 /// # Arguments
330 ///
331 /// * `lhs` - The left hand side tensor.
332 /// * `rhs` - The right hand side scalar.
333 ///
334 /// # Returns
335 ///
336 /// The boolean tensor with the result of the comparison.
337 fn int_greater_equal_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;
338
339 /// Element-wise less than (`lhs < rhs`) comparison between two tensors.
340 ///
341 /// # Arguments
342 ///
343 /// * `lhs` - The left hand side tensor.
344 /// * `rhs` - The right hand side tensor.
345 ///
346 /// # Returns
347 ///
348 /// The boolean tensor with the result of the comparison.
349 fn int_lower(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;
350
351 /// Element-wise less than (`lhs < rhs`) comparison between a tensor and a scalar.
352 ///
353 /// # Arguments
354 ///
355 /// * `lhs` - The left hand side tensor.
356 /// * `rhs` - The right hand side scalar.
357 ///
358 /// # Returns
359 ///
360 /// The boolean tensor with the result of the comparison.
361 fn int_lower_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;
362
363 /// Element-wise less than or equal (`lhs <= rhs`) comparison between two tensors.
364 ///
365 /// # Arguments
366 ///
367 /// * `lhs` - The left hand side tensor.
368 /// * `rhs` - The right hand side tensor.
369 ///
370 /// # Returns
371 ///
372 /// The boolean tensor with the result of the comparison.
373 fn int_lower_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;
374
375 /// Element-wise less than or equal (`lhs <= rhs`) comparison between a tensor and a scalar.
376 ///
377 /// # Arguments
378 ///
379 /// * `lhs` - The left hand side tensor.
380 /// * `rhs` - The right hand side scalar.
381 ///
382 /// # Returns
383 ///
384 /// The boolean tensor with the result of the comparison.
385 fn int_lower_equal_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;
386
387 // ==== NUMERIC ==== //
388
389 /// Element-wise addition (`lhs + rhs`) of two tensors.
390 ///
391 /// # Arguments
392 ///
393 /// * `lhs` - The left hand side tensor.
394 /// * `rhs` - The right hand side tensor.
395 ///
396 /// # Returns
397 ///
398 /// The int tensor holding the result of the addition.
399 fn int_add(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
400
401 /// Element-wise addition (`lhs + rhs`) of a tensor and a scalar.
402 ///
403 /// # Arguments
404 ///
405 /// * `lhs` - The left hand side tensor.
406 /// * `rhs` - The right hand side scalar.
407 ///
408 /// # Returns
409 ///
410 /// The int tensor holding the result of the addition.
411 fn int_add_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
412
413 /// Element-wise power with a IntTensor.
414 ///
415 /// # Arguments
416 ///
417 /// * `lhs` - The left hand side IntTensor.
418 /// * `rhs` - The right hand side IntTensor.
419 ///
420 /// # Returns
421 ///
422 /// The elements of `lhs` raised to the power of the elements of `rhs`.
423 fn int_powi(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
424 B::float_into_int(B::float_powf(
425 B::int_into_float(lhs),
426 B::int_into_float(rhs),
427 ))
428 }
429
430 /// Element-wise power with a floatTensor.
431 ///
432 /// # Arguments
433 ///
434 /// * `lhs` - The left hand side tensor.
435 /// * `rhs` - The right hand side floatTensor.
436 ///
437 /// # Returns
438 ///
439 /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
440 fn int_powf(lhs: IntTensor<B>, rhs: FloatTensor<B>) -> IntTensor<B> {
441 B::float_into_int(B::float_powf(B::int_into_float(lhs), rhs))
442 }
443
444 /// Element-wise power with a scalar.
445 ///
446 /// # Arguments
447 ///
448 /// * `lhs` - The left hand side tensor.
449 /// * `rhs` - The right hand side scalar.
450 ///
451 /// # Returns
452 ///
453 /// The elements of `lhs` raised to the value of `rhs`.
454 fn int_powi_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B> {
455 B::float_into_int(B::float_powf_scalar(B::int_into_float(lhs), rhs.to_f32()))
456 }
457
458 /// Element-wise power with a floatTensor.
459 ///
460 /// # Arguments
461 ///
462 /// * `lhs` - The left hand side tensor.
463 /// * `rhs` - The right hand side scalar.
464 ///
465 /// # Returns
466 ///
467 /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
468 fn int_powf_scalar(lhs: IntTensor<B>, rhs: f32) -> IntTensor<B> {
469 B::float_into_int(B::float_powf_scalar(B::int_into_float(lhs), rhs))
470 }
471
472 /// Clamps a tensor under a minimum value.
473 ///
474 /// # Arguments
475 ///
476 /// * `tensor` - The tensor to clamp.
477 /// * `min` - The minimum value.
478 ///
479 /// # Returns
480 ///
481 /// The clamped tensor.
482 fn int_clamp_min(tensor: IntTensor<B>, min: IntElem<B>) -> IntTensor<B> {
483 let mask = Self::int_lower_elem(tensor.clone(), min);
484 Self::int_mask_fill(tensor, mask, min)
485 }
486
487 /// Clamps a tensor over a maximum value.
488 ///
489 /// # Arguments
490 ///
491 /// * `tensor` - The tensor to clamp.
492 /// * `max` - The maximum value.
493 ///
494 /// # Returns
495 ///
496 /// The clamped tensor.
497 fn int_clamp_max(tensor: IntTensor<B>, max: IntElem<B>) -> IntTensor<B> {
498 let mask = Self::int_greater_elem(tensor.clone(), max);
499 Self::int_mask_fill(tensor, mask, max)
500 }
501
502 /// Clamps a tensor between a minimum and maximum value.
503 ///
504 /// # Arguments
505 ///
506 /// * `tensor` - The tensor to clamp.
507 /// * `min` - The minimum value.
508 /// * `max` - The maximum value.
509 ///
510 /// # Returns
511 ///
512 /// The clamped tensor.
513 fn int_clamp(tensor: IntTensor<B>, min: IntElem<B>, max: IntElem<B>) -> IntTensor<B> {
514 Self::int_clamp_min(Self::int_clamp_max(tensor, max), min)
515 }
516
517 /// Element-wise subtraction (`lhs - rhs`) of two tensors.
518 ///
519 /// # Arguments
520 ///
521 /// * `lhs` - The left hand side tensor.
522 /// * `rhs` - The right hand side tensor.
523 ///
524 /// # Returns
525 ///
526 /// The int tensor holding the result of the subtraction.
527 fn int_sub(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
528
529 /// Element-wise subtraction (`lhs - rhs`) of a scalar from a tensor.
530 ///
531 /// # Arguments
532 ///
533 /// * `lhs` - The left hand side tensor.
534 /// * `rhs` - The right hand side scalar.
535 ///
536 /// # Returns
537 ///
538 /// The int tensor holding the result of the subtraction.
539 fn int_sub_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
540
541 /// Element-wise multiplication (`lhs * rhs`) of two tensors.
542 ///
543 /// # Arguments
544 ///
545 /// * `lhs` - The left hand side tensor.
546 /// * `rhs` - The right hand side tensor.
547 ///
548 /// # Returns
549 ///
550 /// The int tensor holding the result of the multiplication.
551 fn int_mul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
552
553 /// Element-wise multiplication (`lhs * rhs`) of a tensor by a scalar.
554 ///
555 /// # Arguments
556 ///
557 /// * `lhs` - The left hand side tensor.
558 /// * `rhs` - The right hand side scalar.
559 ///
560 /// # Returns
561 ///
562 /// The int tensor holding the result of the multiplication.
563 fn int_mul_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
564
565 /// Element-wise division (`lhs / rhs`) of two tensors.
566 ///
567 /// # Arguments
568 ///
569 /// * `lhs` - The left hand side tensor.
570 /// * `rhs` - The right hand side tensor.
571 ///
572 /// # Returns
573 ///
574 /// The int tensor holding the result of the division. NOTE(review): rounding of inexact quotients is backend-defined here — confirm.
575 fn int_div(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
576
577 /// Element-wise division (`lhs / rhs`) of a tensor by a scalar.
578 ///
579 /// # Arguments
580 ///
581 /// * `lhs` - The left hand side tensor.
582 /// * `rhs` - The right hand side scalar.
583 ///
584 /// # Returns
585 ///
586 /// The int tensor holding the result of the division. NOTE(review): rounding of inexact quotients is backend-defined here — confirm.
587 fn int_div_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
588
589 /// Element-wise modulus of two tensors.
590 ///
591 /// # Arguments
592 /// * `lhs` - The left hand side tensor.
593 /// * `rhs` - The right hand side tensor.
594 ///
595 /// # Returns
596 ///
597 /// The result of applying the modulus of each element of `rhs` to the matching element of `lhs`.
598 fn int_remainder(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
599
600 /// Element-wise modulus of a tensor by a scalar.
601 ///
602 /// # Arguments
603 /// * `lhs` - The left hand side tensor.
604 /// * `rhs` - The right hand side scalar.
605 ///
606 /// # Returns
607 ///
608 /// The result of applying the modulus of the scalar `rhs` to each element of `lhs`.
609 fn int_remainder_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
610
611 /// Element-wise negation.
612 ///
613 /// # Arguments
614 ///
615 /// * `tensor` - The tensor to negate.
616 ///
617 /// # Returns
618 ///
619 /// The negated tensor.
620 fn int_neg(tensor: IntTensor<B>) -> IntTensor<B> {
621 Self::int_mul_scalar(tensor, (-1.0).elem::<IntElem<B>>())
622 }
623
624 /// Creates an int tensor of zeros.
625 ///
626 /// # Arguments
627 ///
628 /// * `shape` - The shape of the tensor to create.
629 /// * `device` - The device to create the tensor on.
630 ///
631 /// # Returns
632 ///
633 /// The tensor of zeros with the given shape.
634 fn int_zeros(shape: Shape, device: &Device<B>) -> IntTensor<B>;
635
636 /// Creates an int tensor of ones.
637 ///
638 /// # Arguments
639 ///
640 /// * `shape` - The shape of the tensor to create.
641 /// * `device` - The device to create the tensor on.
642 ///
643 /// # Returns
644 ///
645 /// The tensor of ones with the given shape.
646 fn int_ones(shape: Shape, device: &Device<B>) -> IntTensor<B>;
647
648 /// Creates a tensor filled with given value.
649 ///
650 /// # Arguments
651 ///
652 /// * `shape` - The shape of the tensor.
653 /// * `fill_value` - The value with which to fill the tensor.
654 /// * `device` - The device to create the tensor on.
655 ///
656 /// # Returns
657 ///
658 /// The tensor filled with given value
659 fn int_full(shape: Shape, fill_value: IntElem<B>, device: &Device<B>) -> IntTensor<B> {
660 Self::int_add_scalar(Self::int_zeros(shape, device), fill_value)
661 }
662
663 /// Sums all elements in the tensor.
664 ///
665 /// # Arguments
666 ///
667 /// * `tensor` - The tensor to sum.
668 ///
669 /// # Returns
670 ///
671 /// An int tensor holding the sum of all elements in the tensor.
672 fn int_sum(tensor: IntTensor<B>) -> IntTensor<B>;
673
674 /// Sums the elements of the tensor along a dimension.
675 ///
676 /// # Arguments
677 ///
678 /// * `tensor` - The tensor to sum.
679 /// * `dim` - The dimension to sum along.
680 ///
681 /// # Returns
682 ///
683 /// An int tensor holding the sums of the elements along the dimension.
684 fn int_sum_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
685
686 /// Computes the product of all elements in the tensor.
687 ///
688 /// # Arguments
689 ///
690 /// * `tensor` - The tensor to compute the product of.
691 ///
692 /// # Returns
693 ///
694 /// An int tensor holding the product of all elements in the tensor.
695 fn int_prod(tensor: IntTensor<B>) -> IntTensor<B>;
696
697 /// Computes the product of the elements of the tensor along a dimension.
698 ///
699 /// # Arguments
700 ///
701 /// * `tensor` - The tensor to compute the product of.
702 /// * `dim` - The dimension to compute the product along.
703 ///
704 /// # Returns
705 ///
706 /// An int tensor holding the products of the elements along the dimension.
707 fn int_prod_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
708
709 /// Computes the mean of all elements in the tensor.
710 ///
711 /// # Arguments
712 ///
713 /// * `tensor` - The tensor to compute the mean of.
714 ///
715 /// # Returns
716 ///
717 /// The mean of all elements in the tensor.
718 fn int_mean(tensor: IntTensor<B>) -> IntTensor<B> {
719 let num_elems = tensor.shape().num_elements();
720 B::int_div_scalar(B::int_sum(tensor), (num_elems as i64).elem())
721 }
722
723 /// Computes the mean of the elements of the tensor along a dimension.
724 ///
725 /// # Arguments
726 ///
727 /// * `tensor` - The tensor to compute the mean of.
/// * `dim` - The dimension to compute the mean along.
728 ///
729 /// # Returns
730 ///
731 /// An int tensor holding the means of the elements along the dimension.
732 fn int_mean_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
733
734 /// Gets the indices of the maximum elements along a dimension.
735 ///
736 /// # Arguments
737 ///
738 /// * `tensor` - The tensor to get the maximum indices of.
739 /// * `dim` - The dimension to get the maximum indices along.
740 ///
741 /// # Returns
742 ///
743 /// An int tensor holding the indices of the maximum elements along the dimension.
744 fn int_argmax(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
745
746 /// Gets the indices of the minimum elements along a dimension.
747 ///
748 /// # Arguments
749 ///
750 /// * `tensor` - The tensor to get the minimum indices of.
751 /// * `dim` - The dimension to get the minimum indices along.
752 ///
753 /// # Returns
754 ///
755 /// An int tensor holding the indices of the minimum elements along the dimension.
756 fn int_argmin(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
757
758 /// Gets the maximum element in the tensor.
759 ///
760 /// # Arguments
761 ///
762 /// * `tensor` - The tensor to get the maximum element of.
763 ///
764 /// # Returns
765 ///
766 /// The maximum element in the tensor.
767 fn int_max(tensor: IntTensor<B>) -> IntTensor<B> {
768 let shape = tensor.shape();
769 let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()]));
770
771 B::int_max_dim(tensor, 0)
772 }
773
774 /// Gets the maximum element in the tensor along a dimension.
775 ///
776 /// # Arguments
777 ///
778 /// * `tensor` - The tensor to get the maximum element of.
779 /// * `dim` - The dimension to get the maximum element along.
780 ///
781 /// # Returns
782 ///
783 /// The maximum element in the tensor along the dimension.
784 fn int_max_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
785 let index = B::int_argmax(tensor.clone(), dim);
786 let ndim = tensor.shape().num_dims();
787
788 B::int_gather(ndim - 1, tensor, index)
789 }
790
791 /// Gets the maximum elements and corresponding indices along a dimension.
792 ///
793 /// # Arguments
794 ///
795 /// * `tensor` - The tensor to get the maximum elements and indices of.
796 /// * `dim` - The dimension to get the maximum elements and indices along.
797 ///
798 /// # Returns
799 ///
800 /// The maximum elements and corresponding indices along the dimension.
801 fn int_max_dim_with_indices(tensor: IntTensor<B>, dim: usize) -> (IntTensor<B>, IntTensor<B>) {
802 let index = B::int_argmax(tensor.clone(), dim);
803 let values = B::int_gather(dim, tensor, index.clone());
804
805 (values, index)
806 }
807
808 /// Gets the maximum absolute element in the tensor.
809 ///
810 /// # Arguments
811 ///
812 /// * `tensor` - The tensor to get the maximum element of.
813 ///
814 /// # Returns
815 ///
816 /// The maximum element in the tensor.
817 fn int_max_abs(tensor: IntTensor<B>) -> IntTensor<B> {
818 let shape = tensor.shape();
819 let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()]));
820
821 B::int_max_abs_dim(tensor, 0)
822 }
823
824 /// Gets the maximum absolute element in the tensor along a dimension.
825 ///
826 /// # Arguments
827 ///
828 /// * `tensor` - The tensor to get the maximum element of.
829 /// * `dim` - The dimension to get the maximum element along.
830 ///
831 /// # Returns
832 ///
833 /// The maximum element in the tensor along the dimension.
834 fn int_max_abs_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
835 B::int_max_dim(B::int_abs(tensor), dim)
836 }
837
838 /// Gets the minimum element in the tensor.
839 ///
840 /// # Arguments
841 ///
842 /// * `tensor` - The tensor to get the minimum element of.
843 ///
844 /// # Returns
845 ///
846 /// The minimum element in the tensor.
847 fn int_min(tensor: IntTensor<B>) -> IntTensor<B> {
848 let shape = tensor.shape();
849 let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()]));
850
851 B::int_min_dim(tensor, 0)
852 }
853
854 /// Gets the minimum elements in the tensor along a dimension.
855 ///
856 /// # Arguments
857 ///
858 /// * `tensor` - The tensor to get the minimum element of.
859 /// * `dim` - The dimension to get the minimum element along.
860 ///
861 /// # Returns
862 ///
863 /// The minimum element in the tensor along the dimension.
864 fn int_min_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
865 let index = B::int_argmin(tensor.clone(), dim);
866 let ndim = tensor.shape().num_dims();
867
868 B::int_gather(ndim - 1, tensor, index)
869 }
870
871 /// Gets the minimum elements and corresponding indices along a dimension.
872 ///
873 /// # Arguments
874 ///
875 /// * `tensor` - The tensor to get the minimum elements and indices of.
876 /// * `dim` - The dimension to get the minimum elements and indices along.
877 ///
878 /// # Returns
879 ///
880 /// The minimum elements and corresponding indices along the dimension.
881 fn int_min_dim_with_indices(tensor: IntTensor<B>, dim: usize) -> (IntTensor<B>, IntTensor<B>) {
882 let indices = B::int_argmin(tensor.clone(), dim);
883 let ndim = tensor.shape().num_dims();
884 let values = B::int_gather(ndim - 1, tensor, indices.clone());
885
886 (values, indices)
887 }
888
889 /// Returns a new tensor with the element-wise absolute values.
890 ///
891 /// # Arguments
892 ///
893 /// * `tensor` - The tensor to take the absolute value of.
894 ///
895 /// # Returns
896 ///
897 /// A tensor with the same shape as `tensor` with absolute values.
898 fn int_abs(tensor: IntTensor<B>) -> IntTensor<B>;
899
900 /// Transposes an int tensor.
901 ///
902 /// # Arguments
903 ///
904 /// * `tensor` - The tensor to transpose.
905 ///
906 /// # Returns
907 ///
908 /// The transposed tensor.
909 fn int_transpose(tensor: IntTensor<B>) -> IntTensor<B> {
910 let ndims = tensor.shape().num_dims();
911 Self::int_swap_dims(tensor, ndims - 2, ndims - 1)
912 }
913
914 /// Swaps two dimensions of an int tensor.
915 ///
916 /// # Arguments
917 ///
918 /// * `tensor` - The tensor to swap the dimensions of.
919 /// * `dim1` - The first dimension to swap.
920 /// * `dim2` - The second dimension to swap.
921 ///
922 /// # Returns
923 ///
924 /// The tensor with dimensions `dim1` and `dim2` swapped.
925 fn int_swap_dims(tensor: IntTensor<B>, dim1: usize, dim2: usize) -> IntTensor<B>;
926
927 /// Permutes the dimensions of an int tensor.
928 ///
929 /// # Arguments
930 ///
931 /// * `tensor` - The tensor to permute the dimensions of.
932 /// * `axes` - The new order of the dimensions.
///
933 /// # Returns
934 ///
935 /// The tensor with the dimensions permuted.
936 fn int_permute(tensor: IntTensor<B>, axes: &[usize]) -> IntTensor<B>;
937
938 /// Reverses the order of elements in a tensor along the given axes.
939 ///
940 /// # Arguments
941 ///
942 /// * `tensor` - The tensor to reverse.
943 /// * `axes` - The axes to reverse.
944 ///
/// # Returns
///
945 /// The tensor with the elements reversed.
946 fn int_flip(tensor: IntTensor<B>, axes: &[usize]) -> IntTensor<B>;
947
948 /// Returns a new tensor with the given dimension narrowed to the given range.
949 ///
950 /// # Arguments
951 ///
952 /// * `dim` - The dimension along which the tensor will be narrowed.
953 /// * `start` - The starting point of the given range.
954 /// * `length` - The ending point of the given range.
955 /// # Panics
956 ///
957 /// - If the dimension is greater than the number of dimensions of the tensor.
958 /// - If the given range exceeds the number of elements on the given dimension.
959 ///
960 /// # Returns
961 ///
962 /// A new tensor with the given dimension narrowed to the given range.
963 fn int_narrow(tensor: IntTensor<B>, dim: usize, start: usize, length: usize) -> IntTensor<B> {
964 narrow::<B, Int>(tensor, dim, start, length)
965 }
966
967 /// Split the tensor along the given dimension into chunks.
968 ///
969 /// # Arguments
970 ///
971 /// * `tensor` - The tensor.
972 /// * `chunks` - The number of chunks to be produced.
973 /// * `times` - The dimension along which the tensor will be split.
974 ///
975 /// # Returns
976 ///
977 /// A vector of tensors
978 fn int_chunk(tensor: IntTensor<B>, chunks: usize, dim: usize) -> Vec<IntTensor<B>> {
979 chunk::<B, Int>(tensor, chunks, dim)
980 }
981
982 /// Split the tensor along the given dimension into chunks of `split_size`.
983 ///
984 /// # Arguments
985 ///
986 /// * `tensor` - The tensor.
987 /// * `split_size` - The size of a single chunk.
988 /// * `times` - The dimension along which the tensor will be split.
989 ///
990 /// # Returns
991 ///
992 /// A vector of tensors.
993 fn int_split(tensor: IntTensor<B>, split_size: usize, dim: usize) -> Vec<IntTensor<B>> {
994 split::<B, Int>(tensor, split_size, dim)
995 }
996
997 /// Split the tensor along the given dimension into chunks with sizes in
998 /// `dim` according to `split_sizes`.
999 ///
1000 /// # Arguments
1001 ///
1002 /// * `tensor` - The tensor.
1003 /// * `split_sizes` - Vector of sizes for each chunk.
1004 /// * `times` - The dimension along which the tensor will be split.
1005 ///
1006 /// # Returns
1007 ///
1008 /// A vector of tensors.
1009 fn int_split_with_sizes(
1010 tensor: IntTensor<B>,
1011 split_sizes: Vec<usize>,
1012 dim: usize,
1013 ) -> Vec<IntTensor<B>> {
1014 split_with_sizes::<B, Int>(tensor, split_sizes, dim)
1015 }
1016
1017 /// Creates a new int tensor with random values sampled from the given distribution.
1018 ///
1019 /// # Arguments
1020 /// * `shape` - The shape of the tensor to create.
1021 /// * `distribution` - The distribution to sample from.
1022 /// * `device` - The device to create the tensor on.
1023 ///
1024 /// # Returns
1025 ///
1026 /// The int tensor with the given shape and random values.
1027 fn int_random(shape: Shape, distribution: Distribution, device: &Device<B>) -> IntTensor<B>;
1028
1029 /// Creates a new tensor with values from the given range with the given step size.
1030 ///
1031 /// # Arguments
1032 ///
1033 /// * `range` - The range of values.
1034 /// * `step` - The step size.
1035 /// * `device` - The device to create the tensor on.
1036 ///
1037 /// # Returns
1038 ///
1039 /// The tensor with the given values.
1040 fn int_arange_step(range: Range<i64>, step: usize, device: &Device<B>) -> IntTensor<B> {
1041 let value = range
1042 .step_by(step)
1043 .map(|i| i.elem())
1044 .collect::<Vec<IntElem<B>>>();
1045 let shape = Shape::new([value.len()]);
1046 let data = TensorData::new(value, shape);
1047 B::int_from_data(data, device)
1048 }
1049
1050 /// Creates a new tensor with values from the given range.
1051 ///
1052 /// # Arguments
1053 ///
1054 /// * `range` - The range of values.
1055 /// * `device` - The device to create the tensor on.
1056 ///
1057 /// # Returns
1058 ///
1059 /// The tensor with the given values.
1060 ///
1061 /// # Remarks
1062 ///
1063 /// Uses `arange_step` with a step size of 1 under the hood.
1064 fn int_arange(range: Range<i64>, device: &Device<B>) -> IntTensor<B> {
1065 Self::int_arange_step(range, 1, device)
1066 }
1067
1068 /// Tests if any element in the int `tensor` evaluates to True.
1069 ///
1070 /// # Arguments
1071 ///
1072 /// * `tensor` - The tensor to test.
1073 ///
1074 /// # Returns
1075 ///
1076 /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
1077 fn int_any(tensor: IntTensor<B>) -> BoolTensor<B> {
1078 let bool_tensor = B::int_equal_elem(tensor, 0.elem());
1079 let bool_tensor = B::bool_not(bool_tensor);
1080 let sum = B::int_sum(B::bool_into_int(bool_tensor));
1081 B::int_greater_elem(sum, 0.elem())
1082 }
1083
1084 /// Tests if any element in the int `tensor` evaluates to True along a given dimension `dim`.
1085 ///
1086 /// # Arguments
1087 ///
1088 /// * `tensor` - The tensor to test.
1089 /// * `dim` - The axis along which to test.
1090 ///
1091 /// # Returns
1092 ///
1093 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1094 /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input
1095 /// evaluates to True, False otherwise.
1096 fn int_any_dim(tensor: IntTensor<B>, dim: usize) -> BoolTensor<B> {
1097 let bool_tensor = B::int_equal_elem(tensor, 0.elem());
1098 let bool_tensor = B::bool_not(bool_tensor);
1099 let sum = B::int_sum_dim(B::bool_into_int(bool_tensor), dim);
1100 B::int_greater_elem(sum, 0.elem())
1101 }
1102
1103 /// Tests if all elements in the int `tensor` evaluate to True.
1104 ///
1105 /// # Arguments
1106 ///
1107 /// * `tensor` - The tensor to test.
1108 ///
1109 /// # Returns
1110 ///
1111 /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1112 /// evaluate to True, False otherwise.
1113 fn int_all(tensor: IntTensor<B>) -> BoolTensor<B> {
1114 let num_elems = tensor.shape().num_elements();
1115 let bool_tensor = B::int_equal_elem(tensor, 0.elem());
1116 let bool_tensor = B::bool_not(bool_tensor);
1117 let sum = B::int_sum(B::bool_into_int(bool_tensor));
1118 B::int_equal_elem(sum, (num_elems as i32).elem())
1119 }
1120
1121 /// Tests if all elements in the int `tensor` evaluate to True along a given dimension `dim`.
1122 ///
1123 /// # Arguments
1124 ///
1125 /// * `tensor` - The tensor to test.
1126 /// * `dim` - The axis along which to test.
1127 ///
1128 /// # Returns
1129 ///
1130 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1131 /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1132 /// evaluates to True, False otherwise.
1133 fn int_all_dim(tensor: IntTensor<B>, dim: usize) -> BoolTensor<B> {
1134 let num_elems = tensor.shape().dims[dim];
1135 let bool_tensor = B::int_equal_elem(tensor, 0.elem());
1136 let bool_tensor = B::bool_not(bool_tensor);
1137 let sum = B::int_sum_dim(B::bool_into_int(bool_tensor), dim);
1138 B::int_equal_elem(sum, (num_elems as i32).elem())
1139 }
1140
1141 /// Returns the signs of the int `tensor`.
1142 ///
1143 /// # Arguments
1144 ///
1145 /// * `tensor` - The tensor to extract the signs from.
1146 ///
1147 /// # Returns
1148 ///
1149 /// A tensor with the same shape as `tensor` containing the signs of the elements of `tensor`.
1150 fn int_sign(tensor: IntTensor<B>) -> IntTensor<B> {
1151 let zeros = B::int_zeros(tensor.shape(), &B::int_device(&tensor));
1152 let less_than_zero = B::int_lower_elem(tensor.clone(), 0.0f32.elem());
1153 let greater_than_zero = B::int_greater_elem(tensor, 0.0f32.elem());
1154
1155 let mut result = B::int_mask_fill(zeros, less_than_zero, (-1.0f32).elem());
1156 result = B::int_mask_fill(result, greater_than_zero, 1.0f32.elem());
1157 result
1158 }
1159
/// Broadcasts the int `tensor` to the given `shape`.
///
/// # Arguments
///
/// * `tensor` - The tensor to broadcast.
/// * `shape` - The shape to broadcast the tensor to.
///
/// # Returns
///
/// The broadcast int tensor.
fn int_expand(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B>;
1162
/// Sorts the elements of the input `tensor` by value along a given dimension.
///
/// This sort is unstable (i.e., may reorder equal elements).
///
/// The default implementation delegates to the crate's generic `sort` helper.
///
/// # Arguments
///
/// * `tensor` - The input tensor.
/// * `dim` - The axis along which to sort.
/// * `descending` - The sorting order.
///
/// # Returns
///
/// A tensor with the same shape as the input tensor, where the elements are sorted by value.
fn int_sort(tensor: IntTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
    sort::<B, Int>(tensor, dim, descending)
}
1179
/// Sorts the elements of the input `tensor` by value along a given dimension and also
/// returns the indices of the sorted elements.
///
/// This sort is unstable (i.e., may reorder equal elements).
///
/// The default implementation delegates to the crate's generic `sort_with_indices` helper.
///
/// # Arguments
///
/// * `tensor` - The input tensor.
/// * `dim` - The axis along which to sort.
/// * `descending` - The sorting order.
///
/// # Returns
///
/// A tensor with the same shape as the input tensor and corresponding indices, where
/// the elements are sorted by value and the indices map back to the original input tensor.
fn int_sort_with_indices(
    tensor: IntTensor<B>,
    dim: usize,
    descending: bool,
) -> (IntTensor<B>, IntTensor<B>) {
    sort_with_indices::<B, Int>(tensor, dim, descending)
}
1200
/// Returns the indices that sort the elements of the input `tensor` by value
/// along a given dimension.
///
/// This sort is unstable (i.e., may reorder equal elements).
///
/// The default implementation delegates to the crate's generic `argsort` helper.
///
/// # Arguments
///
/// * `tensor` - The input tensor.
/// * `dim` - The axis along which to sort.
/// * `descending` - The sorting order.
///
/// # Returns
///
/// A tensor with the same shape as the input tensor, whose indices map back to the
/// original input tensor.
fn int_argsort(tensor: IntTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
    argsort::<B, Int>(tensor, dim, descending)
}
1218
/// Bitwise AND operation for Int Tensors.
///
/// # Arguments
///
/// * `lhs` - The left-hand side int tensor.
/// * `rhs` - The right-hand side int tensor.
///
/// # Returns
///
/// The int tensor holding the result of the operation.
fn bitwise_and(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
1221
/// Bitwise AND operation for Int Tensors with a scalar.
///
/// # Arguments
///
/// * `lhs` - The left-hand side int tensor.
/// * `rhs` - The right-hand side scalar.
///
/// # Returns
///
/// The int tensor holding the result of the operation.
fn bitwise_and_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
1224
/// Bitwise OR operation for Int Tensors.
///
/// # Arguments
///
/// * `lhs` - The left-hand side int tensor.
/// * `rhs` - The right-hand side int tensor.
///
/// # Returns
///
/// The int tensor holding the result of the operation.
fn bitwise_or(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
1227
/// Bitwise OR operation for Int Tensors with a scalar.
///
/// # Arguments
///
/// * `lhs` - The left-hand side int tensor.
/// * `rhs` - The right-hand side scalar.
///
/// # Returns
///
/// The int tensor holding the result of the operation.
fn bitwise_or_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
1230
/// Bitwise XOR operation for Int Tensors.
///
/// # Arguments
///
/// * `lhs` - The left-hand side int tensor.
/// * `rhs` - The right-hand side int tensor.
///
/// # Returns
///
/// The int tensor holding the result of the operation.
fn bitwise_xor(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
1233
/// Bitwise XOR operation for Int Tensors with a scalar.
///
/// # Arguments
///
/// * `lhs` - The left-hand side int tensor.
/// * `rhs` - The right-hand side scalar.
///
/// # Returns
///
/// The int tensor holding the result of the operation.
fn bitwise_xor_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
1236
/// Bitwise NOT operation for Int Tensors.
///
/// # Arguments
///
/// * `tensor` - The int tensor to apply the operation to.
///
/// # Returns
///
/// The int tensor holding the result of the operation.
fn bitwise_not(tensor: IntTensor<B>) -> IntTensor<B>;
1239
/// Bitwise left shift operation for Int Tensors.
///
/// # Arguments
///
/// * `lhs` - The left-hand side int tensor.
/// * `rhs` - The right-hand side int tensor.
///
/// # Returns
///
/// The int tensor holding the result of the operation.
fn bitwise_left_shift(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
1242
/// Bitwise left shift operation for Int Tensors with a scalar.
///
/// # Arguments
///
/// * `lhs` - The left-hand side int tensor.
/// * `rhs` - The right-hand side scalar.
///
/// # Returns
///
/// The int tensor holding the result of the operation.
fn bitwise_left_shift_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
1245
/// Bitwise right shift operation for Int Tensors.
///
/// # Arguments
///
/// * `lhs` - The left-hand side int tensor.
/// * `rhs` - The right-hand side int tensor.
///
/// # Returns
///
/// The int tensor holding the result of the operation.
fn bitwise_right_shift(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
1248
/// Bitwise right shift operation for Int Tensors with a scalar.
///
/// # Arguments
///
/// * `lhs` - The left-hand side int tensor.
/// * `rhs` - The right-hand side scalar.
///
/// # Returns
///
/// The int tensor holding the result of the operation.
fn bitwise_right_shift_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
1251}