burn_tensor/tensor/ops/int_tensor.rs
1use super::cat::cat_with_slice_assign;
2use super::repeat_dim::repeat_with_slice_assign;
3use super::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
4use crate::cast::ToElement;
5use crate::tensor::api::{chunk, narrow, split, split_with_sizes};
6use crate::{Distribution, ElementConversion, Int, TensorData, backend::Backend, tensor::Shape};
7use alloc::vec::Vec;
8use core::future::Future;
9use core::ops::Range;
10
11use crate::{TensorMetadata, argsort, sort, sort_with_indices};
12
13/// Int Tensor API for basic and numeric operations, see [tensor](crate::Tensor)
14/// for documentation on each function.
15pub trait IntTensorOps<B: Backend> {
16 /// Creates a new int tensor.
17 ///
18 /// # Arguments
19 ///
20 /// * `shape` - The shape of the tensor.
21 /// * `device` - The device to create the tensor on.
22 ///
23 /// # Returns
24 ///
25 /// The integer tensor with the given shape.
26 fn int_empty(shape: Shape, device: &Device<B>) -> IntTensor<B>;
27
28 /// Converts the tensor to a data structure.
29 ///
30 /// # Arguments
31 ///
32 /// * `tensor` - The tensor.
33 ///
34 /// # Returns
35 ///
36 /// The data structure with the tensor's data.
37 fn int_into_data(tensor: IntTensor<B>) -> impl Future<Output = TensorData> + 'static + Send;
38
39 /// Creates a tensor from the data structure.
40 ///
41 /// # Arguments
42 ///
43 /// * `data` - The data structure.
44 /// * `device` - The device to create the tensor on.
45 ///
46 /// # Returns
47 ///
48 /// The tensor with the data.
49 fn int_from_data(data: TensorData, device: &Device<B>) -> IntTensor<B>;
50
51 /// Gets the device of the tensor.
52 ///
53 /// # Arguments
54 ///
55 /// * `tensor` - The tensor.
56 ///
57 /// # Returns
58 ///
59 /// The device of the tensor.
60 fn int_device(tensor: &IntTensor<B>) -> Device<B>;
61
/// Moves the tensor to the given device.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `device` - The device to move the tensor to.
///
/// # Returns
///
/// The tensor on the given device.
fn int_to_device(tensor: IntTensor<B>, device: &Device<B>) -> IntTensor<B>;
64
65 /// Reshapes the tensor.
66 ///
67 /// # Arguments
68 ///
69 /// * `tensor` - The tensor.
70 /// * `shape` - The new shape.
71 ///
72 /// # Returns
73 ///
74 /// The tensor with the new shape.
75 fn int_reshape(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B>;
76
77 /// Gets the element at the given indices.
78 ///
79 /// # Arguments
80 ///
81 /// * `tensor` - The tensor.
82 /// * `indices` - The indices.
83 ///
84 /// # Returns
85 ///
86 /// The elements at the given indices.
87 fn int_slice(tensor: IntTensor<B>, indices: &[Range<usize>]) -> IntTensor<B>;
88
/// Sets the elements at the given indices.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `indices` - The index ranges to assign to.
/// * `value` - The tensor holding the values to assign.
///
/// # Returns
///
/// The tensor with the elements at the given indices set.
fn int_slice_assign(
    tensor: IntTensor<B>,
    indices: &[Range<usize>],
    value: IntTensor<B>,
) -> IntTensor<B>;
104
/// Converts int tensor to float tensor.
///
/// # Arguments
///
/// * `tensor` - The int tensor.
///
/// # Returns
///
/// The float tensor with the same data as the int tensor.
fn int_into_float(tensor: IntTensor<B>) -> FloatTensor<B>;
115
116 /// Fills the tensor with values from the source tensor if the mask is true at the given
117 /// indices.
118 ///
119 /// # Arguments
120 ///
121 /// * `tensor` - The tensor.
122 /// * `mask` - The mask.
123 /// * `source` - The source tensor.
124 ///
125 /// # Returns
126 ///
127 /// The tensor with the values filled.
128 fn int_mask_where(
129 tensor: IntTensor<B>,
130 mask: BoolTensor<B>,
131 source: IntTensor<B>,
132 ) -> IntTensor<B>;
133
134 /// Fills the tensor with the given value if the mask is true at the given indices.
135 ///
136 /// # Arguments
137 ///
138 /// * `tensor` - The tensor.
139 /// * `mask` - The mask.
140 /// * `value` - The value.
141 ///
142 /// # Returns
143 ///
144 /// The tensor with the values filled.
145 fn int_mask_fill(tensor: IntTensor<B>, mask: BoolTensor<B>, value: IntElem<B>) -> IntTensor<B>;
146
/// Gather elements from the tensor at the given indices.
///
/// # Arguments
///
/// * `dim` - The dimension to gather from.
/// * `tensor` - The tensor.
/// * `indices` - The indices.
///
/// # Returns
///
/// The tensor with the gathered elements.
fn int_gather(dim: usize, tensor: IntTensor<B>, indices: IntTensor<B>) -> IntTensor<B>;
155
156 /// Scatter a given value to the tensor at the given indices.
157 ///
158 /// # Arguments
159 ///
160 /// * `dim` - The dimension to scatter to.
161 /// * `tensor` - The tensor.
162 /// * `indices` - The indices.
163 /// * `value` - The value.
164 ///
165 /// # Returns
166 ///
167 /// The tensor with the values scattered.
168 fn int_scatter(
169 dim: usize,
170 tensor: IntTensor<B>,
171 indices: IntTensor<B>,
172 value: IntTensor<B>,
173 ) -> IntTensor<B>;
174
175 /// Select tensor elements along the given dimension corresponding to the given indices.
176 ///
177 /// # Arguments
178 ///
179 /// * `tensor` - The tensor.
180 /// * `dim` - The dimension to select from.
181 /// * `indices` - The indices.
182 ///
183 /// # Returns
184 ///
185 /// The tensor with the selected elements.
186 fn int_select(tensor: IntTensor<B>, dim: usize, indices: IntTensor<B>) -> IntTensor<B>;
187
188 /// Assign the selected elements along the given dimension corresponding to the given indices
189 /// to the given value.
190 ///
191 /// # Arguments
192 ///
193 /// * `tensor` - The tensor.
194 /// * `dim` - The dimension to select from.
195 /// * `indices` - The indices.
196 /// * `value` - The value.
197 ///
198 /// # Returns
199 ///
200 /// The tensor with the selected elements assigned to the given value.
201 fn int_select_assign(
202 tensor: IntTensor<B>,
203 dim: usize,
204 indices: IntTensor<B>,
205 value: IntTensor<B>,
206 ) -> IntTensor<B>;
207
208 /// Repeats the tensor along the given dimension the given number of times.
209 ///
210 /// # Arguments
211 ///
212 /// * `tensor` - The tensor.
213 /// * `dim` - The dimension to repeat.
214 /// * `times` - The number of times to repeat.
215 ///
216 /// # Returns
217 ///
218 /// The tensor with the given dimension repeated the given number of times.
219 fn int_repeat_dim(tensor: IntTensor<B>, dim: usize, times: usize) -> IntTensor<B> {
220 repeat_with_slice_assign::<B, Int>(tensor, dim, times)
221 }
222
223 /// Concatenates the given tensors along the given dimension.
224 ///
225 /// # Arguments
226 ///
227 /// * `tensors` - The tensors.
228 /// * `dim` - The dimension to concatenate along.
229 ///
230 /// # Returns
231 ///
232 /// The concatenated tensor.
233 fn int_cat(tensors: Vec<IntTensor<B>>, dim: usize) -> IntTensor<B> {
234 cat_with_slice_assign::<B, Int>(tensors, dim)
235 }
236
237 /// Element-wise equality comparison.
238 ///
239 /// # Arguments
240 ///
241 /// * `lhs` - The left hand side tensor.
242 /// * `rhs` - The right hand side tensor.
243 ///
244 /// # Returns
245 ///
246 /// The boolean tensor with the result of the comparison.
247 fn int_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;
248
249 /// Element-wise non-equality comparison.
250 ///
251 /// # Arguments
252 ///
253 /// * `lhs` - The left hand side tensor.
254 /// * `rhs` - The right hand side tensor.
255 ///
256 /// # Returns
257 ///
258 /// The boolean tensor with the result of the comparison.
259 fn int_not_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {
260 let equal_tensor = B::int_equal(lhs, rhs);
261 B::bool_not(equal_tensor)
262 }
263
264 /// Element-wise equality comparison with a scalar.
265 ///
266 /// # Arguments
267 ///
268 /// * `lhs` - The left hand side tensor.
269 /// * `rhs` - The right hand side scalar.
270 ///
271 /// # Returns
272 ///
273 /// The boolean tensor with the result of the comparison.
274 fn int_equal_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;
275
276 /// Element-wise non-equality comparison with a scalar.
277 ///
278 /// # Arguments
279 ///
280 /// * `lhs` - The left hand side tensor.
281 /// * `rhs` - The right hand side scalar.
282 ///
283 /// # Returns
284 ///
285 /// The boolean tensor with the result of the comparison.
286 fn int_not_equal_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B> {
287 let equal_tensor = B::int_equal_elem(lhs, rhs);
288 B::bool_not(equal_tensor)
289 }
290
291 /// Element-wise greater than comparison.
292 ///
293 /// # Arguments
294 ///
295 /// * `lhs` - The left hand side tensor.
296 /// * `rhs` - The right hand side tensor.
297 ///
298 /// # Returns
299 ///
300 /// The boolean tensor with the result of the comparison.
301 fn int_greater(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;
302
303 /// Element-wise greater than comparison with a scalar.
304 ///
305 /// # Arguments
306 ///
307 /// * `lhs` - The left hand side tensor.
308 /// * `rhs` - The right hand side scalar.
309 ///
310 /// # Returns
311 ///
312 /// The boolean tensor with the result of the comparison.
313 fn int_greater_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;
314
315 /// Element-wise greater than or equal comparison.
316 ///
317 /// # Arguments
318 ///
319 /// * `lhs` - The left hand side tensor.
320 /// * `rhs` - The right hand side tensor.
321 ///
322 /// # Returns
323 ///
324 /// The boolean tensor with the result of the comparison.
325 fn int_greater_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;
326
327 /// Element-wise greater than or equal comparison with a scalar.
328 ///
329 /// # Arguments
330 ///
331 /// * `lhs` - The left hand side tensor.
332 /// * `rhs` - The right hand side scalar.
333 ///
334 /// # Returns
335 ///
336 /// The boolean tensor with the result of the comparison.
337 fn int_greater_equal_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;
338
339 /// Element-wise less than comparison.
340 ///
341 /// # Arguments
342 ///
343 /// * `lhs` - The left hand side tensor.
344 /// * `rhs` - The right hand side tensor.
345 ///
346 /// # Returns
347 ///
348 /// The boolean tensor with the result of the comparison.
349 fn int_lower(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;
350
351 /// Element-wise less than comparison with a scalar.
352 ///
353 /// # Arguments
354 ///
355 /// * `lhs` - The left hand side tensor.
356 /// * `rhs` - The right hand side scalar.
357 ///
358 /// # Returns
359 ///
360 /// The boolean tensor with the result of the comparison.
361 fn int_lower_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;
362
363 /// Element-wise less than or equal comparison.
364 ///
365 /// # Arguments
366 ///
367 /// * `lhs` - The left hand side tensor.
368 /// * `rhs` - The right hand side tensor.
369 ///
370 /// # Returns
371 ///
372 /// The boolean tensor with the result of the comparison.
373 fn int_lower_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;
374
375 /// Element-wise less than or equal comparison with a scalar.
376 ///
377 /// # Arguments
378 ///
379 /// * `lhs` - The left hand side tensor.
380 /// * `rhs` - The right hand side scalar.
381 ///
382 /// # Returns
383 ///
384 /// The boolean tensor with the result of the comparison.
385 fn int_lower_equal_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;
386
387 // ==== NUMERIC ==== //
388
389 /// Element-wise addition.
390 ///
391 /// # Arguments
392 ///
393 /// * `lhs` - The left hand side tensor.
394 /// * `rhs` - The right hand side tensor.
395 ///
396 /// # Returns
397 ///
398 /// The result of the addition.
399 fn int_add(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
400
401 /// Element-wise addition with a scalar.
402 ///
403 /// # Arguments
404 ///
405 /// * `lhs` - The left hand side tensor.
406 /// * `rhs` - The right hand side scalar.
407 ///
408 /// # Returns
409 ///
410 /// The result of the addition.
411 fn int_add_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
412
413 /// Element-wise power with a IntTensor.
414 ///
415 /// # Arguments
416 ///
417 /// * `lhs` - The left hand side IntTensor.
418 /// * `rhs` - The right hand side IntTensor.
419 ///
420 /// # Returns
421 ///
422 /// The elements of `lhs` raised to the power of the elements of `rhs`.
423 fn int_powi(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
424 B::float_into_int(B::float_powf(
425 B::int_into_float(lhs),
426 B::int_into_float(rhs),
427 ))
428 }
429
430 /// Element-wise power with a floatTensor.
431 ///
432 /// # Arguments
433 ///
434 /// * `lhs` - The left hand side tensor.
435 /// * `rhs` - The right hand side floatTensor.
436 ///
437 /// # Returns
438 ///
439 /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
440 fn int_powf(lhs: IntTensor<B>, rhs: FloatTensor<B>) -> IntTensor<B> {
441 B::float_into_int(B::float_powf(B::int_into_float(lhs), rhs))
442 }
443
444 /// Element-wise power with a scalar.
445 ///
446 /// # Arguments
447 ///
448 /// * `lhs` - The left hand side tensor.
449 /// * `rhs` - The right hand side scalar.
450 ///
451 /// # Returns
452 ///
453 /// The elements of `lhs` raised to the value of `rhs`.
454 fn int_powi_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B> {
455 B::float_into_int(B::float_powf_scalar(B::int_into_float(lhs), rhs.to_f32()))
456 }
457
458 /// Element-wise power with a floatTensor.
459 ///
460 /// # Arguments
461 ///
462 /// * `lhs` - The left hand side tensor.
463 /// * `rhs` - The right hand side scalar.
464 ///
465 /// # Returns
466 ///
467 /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
468 fn int_powf_scalar(lhs: IntTensor<B>, rhs: f32) -> IntTensor<B> {
469 B::float_into_int(B::float_powf_scalar(B::int_into_float(lhs), rhs))
470 }
471
472 /// Clamps a tensor under a minimum value.
473 ///
474 /// # Arguments
475 ///
476 /// * `tensor` - The tensor to clamp.
477 /// * `min` - The minimum value.
478 ///
479 /// # Returns
480 ///
481 /// The clamped tensor.
482 fn int_clamp_min(tensor: IntTensor<B>, min: IntElem<B>) -> IntTensor<B> {
483 let mask = Self::int_lower_elem(tensor.clone(), min);
484 Self::int_mask_fill(tensor, mask, min)
485 }
486
487 /// Clamps a tensor over a maximum value.
488 ///
489 /// # Arguments
490 ///
491 /// * `tensor` - The tensor to clamp.
492 /// * `max` - The maximum value.
493 ///
494 /// # Returns
495 ///
496 /// The clamped tensor.
497 fn int_clamp_max(tensor: IntTensor<B>, max: IntElem<B>) -> IntTensor<B> {
498 let mask = Self::int_greater_elem(tensor.clone(), max);
499 Self::int_mask_fill(tensor, mask, max)
500 }
501
502 /// Clamps a tensor between a minimum and maximum value.
503 ///
504 /// # Arguments
505 ///
506 /// * `tensor` - The tensor to clamp.
507 /// * `min` - The minimum value.
508 /// * `max` - The maximum value.
509 ///
510 /// # Returns
511 ///
512 /// The clamped tensor.
513 fn int_clamp(tensor: IntTensor<B>, min: IntElem<B>, max: IntElem<B>) -> IntTensor<B> {
514 Self::int_clamp_min(Self::int_clamp_max(tensor, max), min)
515 }
516
517 /// Element-wise subtraction.
518 ///
519 /// # Arguments
520 ///
521 /// * `lhs` - The left hand side tensor.
522 /// * `rhs` - The right hand side tensor.
523 ///
524 /// # Returns
525 ///
526 /// The result of the subtraction.
527 fn int_sub(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
528
529 /// Element-wise subtraction with a scalar.
530 ///
531 /// # Arguments
532 ///
533 /// * `lhs` - The left hand side tensor.
534 /// * `rhs` - The right hand side scalar.
535 ///
536 /// # Returns
537 ///
538 /// The result of the subtraction.
539 fn int_sub_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
540
541 /// Element-wise multiplication.
542 ///
543 /// # Arguments
544 ///
545 /// * `lhs` - The left hand side tensor.
546 /// * `rhs` - The right hand side tensor.
547 ///
548 /// # Returns
549 ///
550 /// The result of the multiplication.
551 fn int_mul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
552
553 /// Element-wise multiplication with a scalar.
554 ///
555 /// # Arguments
556 ///
557 /// * `lhs` - The left hand side tensor.
558 /// * `rhs` - The right hand side scalar.
559 ///
560 /// # Returns
561 ///
562 /// The result of the multiplication.
563 fn int_mul_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
564
565 /// Element-wise division.
566 ///
567 /// # Arguments
568 ///
569 /// * `lhs` - The left hand side tensor.
570 /// * `rhs` - The right hand side tensor.
571 ///
572 /// # Returns
573 ///
574 /// The result of the division.
575 fn int_div(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
576
577 /// Element-wise division with a scalar.
578 ///
579 /// # Arguments
580 ///
581 /// * `lhs` - The left hand side tensor.
582 /// * `rhs` - The right hand side scalar.
583 ///
584 /// # Returns
585 ///
586 /// The result of the division.
587 fn int_div_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
588
/// Element-wise modulus.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// The result of applying the modulus of `rhs` to `lhs`, element-wise.
fn int_remainder(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
599
600 /// Element-wise modulus with a scalar.
601 ///
602 /// # Arguments
603 /// * `lhs` - The left hand side tensor.
604 /// * `rhs` - The right hand side scalar.
605 ///
606 /// # Returns
607 ///
608 /// The result of applying the modulus of the scalar to the tensor.
609 fn int_remainder_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
610
611 /// Element-wise negation.
612 ///
613 /// # Arguments
614 ///
615 /// * `tensor` - The tensor to negate.
616 ///
617 /// # Returns
618 ///
619 /// The negated tensor.
620 fn int_neg(tensor: IntTensor<B>) -> IntTensor<B> {
621 Self::int_mul_scalar(tensor, (-1.0).elem::<IntElem<B>>())
622 }
623
624 /// Creates a tensor of zeros.
625 ///
626 /// # Arguments
627 ///
628 /// * `shape` - The shape of the tensor.
629 /// * `device` - The device to create the tensor on.
630 ///
631 /// # Returns
632 ///
633 /// The tensor of zeros.
634 fn int_zeros(shape: Shape, device: &Device<B>) -> IntTensor<B>;
635
636 /// Creates a tensor of ones.
637 ///
638 /// # Arguments
639 ///
640 /// * `shape` - The shape of the tensor.
641 /// * `device` - The device to create the tensor on.
642 ///
643 /// # Returns
644 ///
645 /// The tensor of ones.
646 fn int_ones(shape: Shape, device: &Device<B>) -> IntTensor<B>;
647
648 /// Creates a tensor filled with given value.
649 ///
650 /// # Arguments
651 ///
652 /// * `shape` - The shape of the tensor.
653 /// * `fill_value` - The value with which to fill the tensor.
654 /// * `device` - The device to create the tensor on.
655 ///
656 /// # Returns
657 ///
658 /// The tensor filled with given value
659 fn int_full(shape: Shape, fill_value: IntElem<B>, device: &Device<B>) -> IntTensor<B> {
660 Self::int_add_scalar(Self::int_zeros(shape, device), fill_value)
661 }
662
663 /// Sums all elements in the tensor.
664 ///
665 /// # Arguments
666 ///
667 /// * `tensor` - The tensor to sum.
668 ///
669 /// # Returns
670 ///
671 /// The sum of all elements in the tensor.
672 fn int_sum(tensor: IntTensor<B>) -> IntTensor<B>;
673
674 /// Sums all elements in the tensor along a dimension.
675 ///
676 /// # Arguments
677 ///
678 /// * `tensor` - The tensor to sum.
679 /// * `dim` - The dimension to sum along.
680 ///
681 /// # Returns
682 ///
683 /// The sum of all elements in the tensor along the dimension.
684 fn int_sum_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
685
686 /// Computes the product of all elements in the tensor.
687 ///
688 /// # Arguments
689 ///
690 /// * `tensor` - The tensor to compute the product of.
691 ///
692 /// # Returns
693 ///
694 /// The product of all elements in the tensor.
695 fn int_prod(tensor: IntTensor<B>) -> IntTensor<B>;
696
697 /// Computes the product of all elements in the tensor along a dimension.
698 ///
699 /// # Arguments
700 ///
701 /// * `tensor` - The tensor to compute the product of.
702 /// * `dim` - The dimension to compute the product along.
703 ///
704 /// # Returns
705 ///
706 /// The product of all elements in the tensor along the dimension.
707 fn int_prod_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
708
709 /// Computes the mean of all elements in the tensor.
710 ///
711 /// # Arguments
712 ///
713 /// * `tensor` - The tensor to compute the mean of.
714 ///
715 /// # Returns
716 ///
717 /// The mean of all elements in the tensor.
718 fn int_mean(tensor: IntTensor<B>) -> IntTensor<B> {
719 let num_elems = tensor.shape().num_elements();
720 B::int_div_scalar(B::int_sum(tensor), (num_elems as i64).elem())
721 }
722
/// Computes the mean of all elements in the tensor along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to compute the mean of.
/// * `dim` - The dimension to compute the mean along.
///
/// # Returns
///
/// The mean of all elements in the tensor along the dimension.
fn int_mean_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
733
734 /// Gets the indices of the maximum elements along a dimension.
735 ///
736 /// # Arguments
737 ///
738 /// * `tensor` - The tensor to get the maximum indices of.
739 /// * `dim` - The dimension to get the maximum indices along.
740 ///
741 /// # Returns
742 ///
743 /// The indices of the maximum elements along the dimension.
744 fn int_argmax(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
745
746 /// Gets the indices of the minimum elements along a dimension.
747 ///
748 /// # Arguments
749 ///
750 /// * `tensor` - The tensor to get the minimum indices of.
751 /// * `dim` - The dimension to get the minimum indices along.
752 ///
753 /// # Returns
754 ///
755 /// The indices of the minimum elements along the dimension.
756 fn int_argmin(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
757
758 /// Gets the maximum element in the tensor.
759 ///
760 /// # Arguments
761 ///
762 /// * `tensor` - The tensor to get the maximum element of.
763 ///
764 /// # Returns
765 ///
766 /// The maximum element in the tensor.
767 fn int_max(tensor: IntTensor<B>) -> IntTensor<B> {
768 let shape = tensor.shape();
769 let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()]));
770
771 B::int_max_dim(tensor, 0)
772 }
773
774 /// Gets the maximum element in the tensor along a dimension.
775 ///
776 /// # Arguments
777 ///
778 /// * `tensor` - The tensor to get the maximum element of.
779 /// * `dim` - The dimension to get the maximum element along.
780 ///
781 /// # Returns
782 ///
783 /// The maximum element in the tensor along the dimension.
784 fn int_max_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
785 let index = B::int_argmax(tensor.clone(), dim);
786 B::int_gather(dim, tensor, index)
787 }
788
789 /// Gets the maximum elements and corresponding indices along a dimension.
790 ///
791 /// # Arguments
792 ///
793 /// * `tensor` - The tensor to get the maximum elements and indices of.
794 /// * `dim` - The dimension to get the maximum elements and indices along.
795 ///
796 /// # Returns
797 ///
798 /// The maximum elements and corresponding indices along the dimension.
799 fn int_max_dim_with_indices(tensor: IntTensor<B>, dim: usize) -> (IntTensor<B>, IntTensor<B>) {
800 let index = B::int_argmax(tensor.clone(), dim);
801 let values = B::int_gather(dim, tensor, index.clone());
802
803 (values, index)
804 }
805
806 /// Gets the maximum absolute element in the tensor.
807 ///
808 /// # Arguments
809 ///
810 /// * `tensor` - The tensor to get the maximum element of.
811 ///
812 /// # Returns
813 ///
814 /// The maximum element in the tensor.
815 fn int_max_abs(tensor: IntTensor<B>) -> IntTensor<B> {
816 let shape = tensor.shape();
817 let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()]));
818
819 B::int_max_abs_dim(tensor, 0)
820 }
821
822 /// Gets the maximum absolute element in the tensor along a dimension.
823 ///
824 /// # Arguments
825 ///
826 /// * `tensor` - The tensor to get the maximum element of.
827 /// * `dim` - The dimension to get the maximum element along.
828 ///
829 /// # Returns
830 ///
831 /// The maximum element in the tensor along the dimension.
832 fn int_max_abs_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
833 B::int_max_dim(B::int_abs(tensor), dim)
834 }
835
836 /// Gets the minimum element in the tensor.
837 ///
838 /// # Arguments
839 ///
840 /// * `tensor` - The tensor to get the minimum element of.
841 ///
842 /// # Returns
843 ///
844 /// The minimum element in the tensor.
845 fn int_min(tensor: IntTensor<B>) -> IntTensor<B> {
846 let shape = tensor.shape();
847 let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()]));
848
849 B::int_min_dim(tensor, 0)
850 }
851
852 /// Gets the minimum elements in the tensor along a dimension.
853 ///
854 /// # Arguments
855 ///
856 /// * `tensor` - The tensor to get the minimum element of.
857 /// * `dim` - The dimension to get the minimum element along.
858 ///
859 /// # Returns
860 ///
861 /// The minimum element in the tensor along the dimension.
862 fn int_min_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
863 let index = B::int_argmin(tensor.clone(), dim);
864 B::int_gather(dim, tensor, index)
865 }
866
867 /// Gets the minimum elements and corresponding indices along a dimension.
868 ///
869 /// # Arguments
870 ///
871 /// * `tensor` - The tensor to get the minimum elements and indices of.
872 /// * `dim` - The dimension to get the minimum elements and indices along.
873 ///
874 /// # Returns
875 ///
876 /// The minimum elements and corresponding indices along the dimension.
877 fn int_min_dim_with_indices(tensor: IntTensor<B>, dim: usize) -> (IntTensor<B>, IntTensor<B>) {
878 let indices = B::int_argmin(tensor.clone(), dim);
879 let values = B::int_gather(dim, tensor, indices.clone());
880
881 (values, indices)
882 }
883
884 /// Returns a new tensor with absolute values.
885 ///
886 /// # Arguments
887 ///
888 /// * `tensor` - The tensor to take absolute value of.
889 ///
890 /// # Returns
891 ///
892 /// A tensor with the same shape as `tensor` with absolute values.
893 fn int_abs(tensor: IntTensor<B>) -> IntTensor<B>;
894
895 /// Transposes an int tensor.
896 ///
897 /// # Arguments
898 ///
899 /// * `tensor` - The tensor to transpose.
900 ///
901 /// # Returns
902 ///
903 /// The transposed tensor.
904 fn int_transpose(tensor: IntTensor<B>) -> IntTensor<B> {
905 let ndims = tensor.shape().num_dims();
906 Self::int_swap_dims(tensor, ndims - 2, ndims - 1)
907 }
908
909 /// Swaps two dimensions of an int tensor.
910 ///
911 /// # Arguments
912 ///
913 /// * `tensor` - The tensor to swap the dimensions of.
914 /// * `dim1` - The first dimension to swap.
915 /// * `dim2` - The second dimension to swap.
916 ///
917 /// # Returns
918 ///
919 /// The tensor with the dimensions swapped.
920 fn int_swap_dims(tensor: IntTensor<B>, dim1: usize, dim2: usize) -> IntTensor<B>;
921
/// Permutes the dimensions of a tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to permute the dimensions of.
/// * `axes` - The new order of the dimensions.
///
/// # Returns
///
/// The tensor with the dimensions permuted.
fn int_permute(tensor: IntTensor<B>, axes: &[usize]) -> IntTensor<B>;
932
/// Reverse the order of elements in a tensor along the given axes.
///
/// # Arguments
///
/// * `tensor` - The tensor to reverse.
/// * `axes` - The axes to reverse.
///
/// # Returns
///
/// The tensor with the elements reversed.
fn int_flip(tensor: IntTensor<B>, axes: &[usize]) -> IntTensor<B>;
942
943 /// Returns a new tensor with the given dimension narrowed to the given range.
944 ///
945 /// # Arguments
946 ///
947 /// * `dim` - The dimension along which the tensor will be narrowed.
948 /// * `start` - The starting point of the given range.
949 /// * `length` - The ending point of the given range.
950 /// # Panics
951 ///
952 /// - If the dimension is greater than the number of dimensions of the tensor.
953 /// - If the given range exceeds the number of elements on the given dimension.
954 ///
955 /// # Returns
956 ///
957 /// A new tensor with the given dimension narrowed to the given range.
958 fn int_narrow(tensor: IntTensor<B>, dim: usize, start: usize, length: usize) -> IntTensor<B> {
959 narrow::<B, Int>(tensor, dim, start, length)
960 }
961
962 /// Split the tensor along the given dimension into chunks.
963 ///
964 /// # Arguments
965 ///
966 /// * `tensor` - The tensor.
967 /// * `chunks` - The number of chunks to be produced.
968 /// * `times` - The dimension along which the tensor will be split.
969 ///
970 /// # Returns
971 ///
972 /// A vector of tensors
973 fn int_chunk(tensor: IntTensor<B>, chunks: usize, dim: usize) -> Vec<IntTensor<B>> {
974 chunk::<B, Int>(tensor, chunks, dim)
975 }
976
977 /// Split the tensor along the given dimension into chunks of `split_size`.
978 ///
979 /// # Arguments
980 ///
981 /// * `tensor` - The tensor.
982 /// * `split_size` - The size of a single chunk.
983 /// * `times` - The dimension along which the tensor will be split.
984 ///
985 /// # Returns
986 ///
987 /// A vector of tensors.
988 fn int_split(tensor: IntTensor<B>, split_size: usize, dim: usize) -> Vec<IntTensor<B>> {
989 split::<B, Int>(tensor, split_size, dim)
990 }
991
992 /// Split the tensor along the given dimension into chunks with sizes in
993 /// `dim` according to `split_sizes`.
994 ///
995 /// # Arguments
996 ///
997 /// * `tensor` - The tensor.
998 /// * `split_sizes` - Vector of sizes for each chunk.
999 /// * `times` - The dimension along which the tensor will be split.
1000 ///
1001 /// # Returns
1002 ///
1003 /// A vector of tensors.
1004 fn int_split_with_sizes(
1005 tensor: IntTensor<B>,
1006 split_sizes: Vec<usize>,
1007 dim: usize,
1008 ) -> Vec<IntTensor<B>> {
1009 split_with_sizes::<B, Int>(tensor, split_sizes, dim)
1010 }
1011
1012 /// Creates a new int tensor with random values.
1013 ///
1014 /// # Arguments
1015 /// * `shape` - The shape of the tensor.
1016 /// * `distribution` - The distribution to sample from.
1017 /// * `device` - The device to create the tensor on.
1018 ///
1019 /// # Returns
1020 ///
1021 /// The tensor with the given shape and random values.
1022 fn int_random(shape: Shape, distribution: Distribution, device: &Device<B>) -> IntTensor<B>;
1023
1024 /// Creates a new tensor with values from the given range with the given step size.
1025 ///
1026 /// # Arguments
1027 ///
1028 /// * `range` - The range of values.
1029 /// * `step` - The step size.
1030 /// * `device` - The device to create the tensor on.
1031 ///
1032 /// # Returns
1033 ///
1034 /// The tensor with the given values.
1035 fn int_arange_step(range: Range<i64>, step: usize, device: &Device<B>) -> IntTensor<B> {
1036 let value = range
1037 .step_by(step)
1038 .map(|i| i.elem())
1039 .collect::<Vec<IntElem<B>>>();
1040 let shape = Shape::new([value.len()]);
1041 let data = TensorData::new(value, shape);
1042 B::int_from_data(data, device)
1043 }
1044
1045 /// Creates a new tensor with values from the given range.
1046 ///
1047 /// # Arguments
1048 ///
1049 /// * `range` - The range of values.
1050 /// * `device` - The device to create the tensor on.
1051 ///
1052 /// # Returns
1053 ///
1054 /// The tensor with the given values.
1055 ///
1056 /// # Remarks
1057 ///
1058 /// Uses `arange_step` with a step size of 1 under the hood.
1059 fn int_arange(range: Range<i64>, device: &Device<B>) -> IntTensor<B> {
1060 Self::int_arange_step(range, 1, device)
1061 }
1062
1063 /// Tests if any element in the int `tensor` evaluates to True.
1064 ///
1065 /// # Arguments
1066 ///
1067 /// * `tensor` - The tensor to test.
1068 ///
1069 /// # Returns
1070 ///
1071 /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
1072 fn int_any(tensor: IntTensor<B>) -> BoolTensor<B> {
1073 let bool_tensor = B::int_equal_elem(tensor, 0.elem());
1074 let bool_tensor = B::bool_not(bool_tensor);
1075 let sum = B::int_sum(B::bool_into_int(bool_tensor));
1076 B::int_greater_elem(sum, 0.elem())
1077 }
1078
1079 /// Tests if any element in the int `tensor` evaluates to True along a given dimension `dim`.
1080 ///
1081 /// # Arguments
1082 ///
1083 /// * `tensor` - The tensor to test.
1084 /// * `dim` - The axis along which to test.
1085 ///
1086 /// # Returns
1087 ///
1088 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1089 /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input
1090 /// evaluates to True, False otherwise.
1091 fn int_any_dim(tensor: IntTensor<B>, dim: usize) -> BoolTensor<B> {
1092 let bool_tensor = B::int_equal_elem(tensor, 0.elem());
1093 let bool_tensor = B::bool_not(bool_tensor);
1094 let sum = B::int_sum_dim(B::bool_into_int(bool_tensor), dim);
1095 B::int_greater_elem(sum, 0.elem())
1096 }
1097
1098 /// Tests if all elements in the int `tensor` evaluate to True.
1099 ///
1100 /// # Arguments
1101 ///
1102 /// * `tensor` - The tensor to test.
1103 ///
1104 /// # Returns
1105 ///
1106 /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1107 /// evaluate to True, False otherwise.
1108 fn int_all(tensor: IntTensor<B>) -> BoolTensor<B> {
1109 let num_elems = tensor.shape().num_elements();
1110 let bool_tensor = B::int_equal_elem(tensor, 0.elem());
1111 let bool_tensor = B::bool_not(bool_tensor);
1112 let sum = B::int_sum(B::bool_into_int(bool_tensor));
1113 B::int_equal_elem(sum, (num_elems as i32).elem())
1114 }
1115
1116 /// Tests if all elements in the int `tensor` evaluate to True along a given dimension `dim`.
1117 ///
1118 /// # Arguments
1119 ///
1120 /// * `tensor` - The tensor to test.
1121 /// * `dim` - The axis along which to test.
1122 ///
1123 /// # Returns
1124 ///
1125 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1126 /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1127 /// evaluates to True, False otherwise.
1128 fn int_all_dim(tensor: IntTensor<B>, dim: usize) -> BoolTensor<B> {
1129 let num_elems = tensor.shape().dims[dim];
1130 let bool_tensor = B::int_equal_elem(tensor, 0.elem());
1131 let bool_tensor = B::bool_not(bool_tensor);
1132 let sum = B::int_sum_dim(B::bool_into_int(bool_tensor), dim);
1133 B::int_equal_elem(sum, (num_elems as i32).elem())
1134 }
1135
1136 /// Returns the signs of the int `tensor`.
1137 ///
1138 /// # Arguments
1139 ///
1140 /// * `tensor` - The tensor to extract the signs from.
1141 ///
1142 /// # Returns
1143 ///
1144 /// A tensor with the same shape as `tensor` containing the signs of the elements of `tensor`.
1145 fn int_sign(tensor: IntTensor<B>) -> IntTensor<B> {
1146 let zeros = B::int_zeros(tensor.shape(), &B::int_device(&tensor));
1147 let less_than_zero = B::int_lower_elem(tensor.clone(), 0.0f32.elem());
1148 let greater_than_zero = B::int_greater_elem(tensor, 0.0f32.elem());
1149
1150 let mut result = B::int_mask_fill(zeros, less_than_zero, (-1.0f32).elem());
1151 result = B::int_mask_fill(result, greater_than_zero, 1.0f32.elem());
1152 result
1153 }
1154
/// Broadcasts the int `tensor` to the given `shape`.
///
/// # Arguments
///
/// * `tensor` - The tensor to broadcast.
/// * `shape` - The shape to broadcast to.
///
/// # Returns
///
/// The broadcast int tensor.
fn int_expand(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B>;
1157
1158 /// Sort the elements of the input `tensor` by value along a given dimension.
1159 ///
1160 /// This sort is unstable (i.e., may reorder equal elements).
1161 ///
1162 /// # Arguments
1163 ///
1164 /// * `tensor` - The input tensor.
1165 /// * `dim` - The axis along which to sort.
1166 /// * `descending` - The sorting order.
1167 ///
1168 /// # Returns
1169 ///
1170 /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
1171 fn int_sort(tensor: IntTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
1172 sort::<B, Int>(tensor, dim, descending)
1173 }
1174
1175 /// Sort the elements of the input `tensor` by value along a given dimension.
1176 ///
1177 /// This sort is unstable (i.e., may reorder equal elements).
1178 ///
1179 /// # Arguments
1180 ///
1181 /// * `tensor` - The input tensor.
1182 /// * `dim` - The axis along which to sort.
1183 ///
1184 /// # Returns
1185 ///
1186 /// A tensor with the same shape as the input tensor and corresponding indices, where
1187 /// the elements are sorted by value and the indices map back to the original input tensor.
1188 fn int_sort_with_indices(
1189 tensor: IntTensor<B>,
1190 dim: usize,
1191 descending: bool,
1192 ) -> (IntTensor<B>, IntTensor<B>) {
1193 sort_with_indices::<B, Int>(tensor, dim, descending)
1194 }
1195
1196 /// Returns the indices that sort the elements of the input `tensor` by value
1197 /// along a given dimension.
1198 ///
1199 /// This sort is unstable (i.e., may reorder equal elements).
1200 ///
1201 /// # Arguments
1202 ///
1203 /// * `tensor` - The input tensor.
1204 /// * `dim` - The axis along which to sort.
1205 /// * `descending` - The sorting order.
1206 ///
1207 /// # Returns
1208 ///
1209 /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.
1210 fn int_argsort(tensor: IntTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
1211 argsort::<B, Int>(tensor, dim, descending)
1212 }
1213
/// Bitwise AND operation for int tensors.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side tensor.
///
/// # Returns
///
/// An int tensor holding the result of the bitwise AND.
fn bitwise_and(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
1216
/// Bitwise AND operation for int tensors with a scalar.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side scalar.
///
/// # Returns
///
/// An int tensor holding the result of the bitwise AND.
fn bitwise_and_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
1219
/// Bitwise OR operation for int tensors.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side tensor.
///
/// # Returns
///
/// An int tensor holding the result of the bitwise OR.
fn bitwise_or(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
1222
/// Bitwise OR operation for int tensors with a scalar.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side scalar.
///
/// # Returns
///
/// An int tensor holding the result of the bitwise OR.
fn bitwise_or_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
1225
/// Bitwise XOR operation for int tensors.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side tensor.
///
/// # Returns
///
/// An int tensor holding the result of the bitwise XOR.
fn bitwise_xor(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
1228
/// Bitwise XOR operation for int tensors with a scalar.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side scalar.
///
/// # Returns
///
/// An int tensor holding the result of the bitwise XOR.
fn bitwise_xor_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
1231
/// Bitwise NOT operation for int tensors.
///
/// # Arguments
///
/// * `tensor` - The tensor to apply the bitwise NOT to.
///
/// # Returns
///
/// An int tensor holding the result of the bitwise NOT.
fn bitwise_not(tensor: IntTensor<B>) -> IntTensor<B>;
1234
/// Bitwise left shift operation for int tensors.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side tensor.
///
/// # Returns
///
/// An int tensor holding the result of the bitwise left shift.
fn bitwise_left_shift(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
1237
/// Bitwise left shift operation for int tensors with a scalar.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side scalar.
///
/// # Returns
///
/// An int tensor holding the result of the bitwise left shift.
fn bitwise_left_shift_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
1240
/// Bitwise right shift operation for int tensors.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side tensor.
///
/// # Returns
///
/// An int tensor holding the result of the bitwise right shift.
fn bitwise_right_shift(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
1243
/// Bitwise right shift operation for int tensors with a scalar.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side scalar.
///
/// # Returns
///
/// An int tensor holding the result of the bitwise right shift.
fn bitwise_right_shift_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
1246}