burn_tensor/tensor/ops/
int_tensor.rs

1use super::cat::cat_with_slice_assign;
2use super::repeat_dim::repeat_with_slice_assign;
3use super::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
4use crate::cast::ToElement;
5use crate::tensor::api::{chunk, narrow, split, split_with_sizes};
6use crate::{backend::Backend, tensor::Shape, Distribution, ElementConversion, Int, TensorData};
7use alloc::vec::Vec;
8use core::future::Future;
9use core::ops::Range;
10
11use crate::{argsort, sort, sort_with_indices, TensorMetadata};
12
13/// Int Tensor API for basic and numeric operations, see [tensor](crate::Tensor)
14/// for documentation on each function.
15pub trait IntTensorOps<B: Backend> {
    /// Creates a new int tensor.
    ///
    /// # Arguments
    ///
    /// * `shape` - The shape of the tensor.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// The integer tensor with the given shape.
    // NOTE(review): this signature does not say what values an "empty" tensor holds;
    // presumably they are unspecified/backend-defined — confirm before reading them.
    fn int_empty(shape: Shape, device: &Device<B>) -> IntTensor<B>;

    /// Converts the tensor to a data structure.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// A future resolving to the data structure with the tensor's data. The future is
    /// `Send + 'static`, so it does not borrow from the caller.
    fn int_into_data(tensor: IntTensor<B>) -> impl Future<Output = TensorData> + 'static + Send;

    /// Creates a tensor from the data structure.
    ///
    /// # Arguments
    ///
    /// * `data` - The data structure.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// The tensor with the data.
    fn int_from_data(data: TensorData, device: &Device<B>) -> IntTensor<B>;

    /// Gets the device of the tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The device of the tensor.
    fn int_device(tensor: &IntTensor<B>) -> Device<B>;

    /// Moves the tensor to the given device.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to move.
    /// * `device` - The destination device.
    ///
    /// # Returns
    ///
    /// The tensor located on `device`.
    fn int_to_device(tensor: IntTensor<B>, device: &Device<B>) -> IntTensor<B>;
64
    /// Reshapes the tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `shape` - The new shape.
    ///
    /// # Returns
    ///
    /// The tensor with the new shape.
    fn int_reshape(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B>;

    /// Gets the elements selected by the given index ranges.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `indices` - The ranges of indices to select.
    ///
    /// # Returns
    ///
    /// The elements at the given indices.
    fn int_slice(tensor: IntTensor<B>, indices: &[Range<usize>]) -> IntTensor<B>;

    /// Sets the elements selected by the given index ranges to `value`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `indices` - The ranges of indices to assign to.
    /// * `value` - The tensor holding the values written into the selected elements.
    ///
    /// # Returns
    ///
    /// The tensor with the elements at the given indices set.
    fn int_slice_assign(
        tensor: IntTensor<B>,
        indices: &[Range<usize>],
        value: IntTensor<B>,
    ) -> IntTensor<B>;

    /// Converts int tensor to float tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The int tensor.
    ///
    /// # Returns
    ///
    /// The float tensor with the same data as the int tensor.
    fn int_into_float(tensor: IntTensor<B>) -> FloatTensor<B>;
115
    /// Fills the tensor with values from the source tensor if the mask is true at the given
    /// indices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `mask` - The mask.
    /// * `source` - The source tensor.
    ///
    /// # Returns
    ///
    /// The tensor with the values filled.
    fn int_mask_where(
        tensor: IntTensor<B>,
        mask: BoolTensor<B>,
        source: IntTensor<B>,
    ) -> IntTensor<B>;

    /// Fills the tensor with the given value if the mask is true at the given indices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `mask` - The mask.
    /// * `value` - The value.
    ///
    /// # Returns
    ///
    /// The tensor with the values filled.
    fn int_mask_fill(tensor: IntTensor<B>, mask: BoolTensor<B>, value: IntElem<B>) -> IntTensor<B>;

    /// Gather elements from the tensor at the given indices.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension to gather from.
    /// * `tensor` - The tensor.
    /// * `indices` - The indices.
    ///
    /// # Returns
    ///
    /// The tensor with the gathered elements.
    fn int_gather(dim: usize, tensor: IntTensor<B>, indices: IntTensor<B>) -> IntTensor<B>;

    /// Scatter a given value to the tensor at the given indices.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension to scatter to.
    /// * `tensor` - The tensor.
    /// * `indices` - The indices.
    /// * `value` - The value.
    ///
    /// # Returns
    ///
    /// The tensor with the values scattered.
    fn int_scatter(
        dim: usize,
        tensor: IntTensor<B>,
        indices: IntTensor<B>,
        value: IntTensor<B>,
    ) -> IntTensor<B>;

    /// Select tensor elements along the given dimension corresponding to the given indices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `dim` - The dimension to select from.
    /// * `indices` - The indices.
    ///
    /// # Returns
    ///
    /// The tensor with the selected elements.
    fn int_select(tensor: IntTensor<B>, dim: usize, indices: IntTensor<B>) -> IntTensor<B>;

    /// Assign the selected elements along the given dimension corresponding to the given indices
    /// to the given value.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `dim` - The dimension to select from.
    /// * `indices` - The indices.
    /// * `value` - The value.
    ///
    /// # Returns
    ///
    /// The tensor with the selected elements assigned to the given value.
    fn int_select_assign(
        tensor: IntTensor<B>,
        dim: usize,
        indices: IntTensor<B>,
        value: IntTensor<B>,
    ) -> IntTensor<B>;
207
208    /// Repeats the tensor along the given dimension the given number of times.
209    ///
210    /// # Arguments
211    ///
212    /// * `tensor` - The tensor.
213    /// * `dim` - The dimension to repeat.
214    /// * `times` - The number of times to repeat.
215    ///
216    /// # Returns
217    ///
218    /// The tensor with the given dimension repeated the given number of times.
219    fn int_repeat_dim(tensor: IntTensor<B>, dim: usize, times: usize) -> IntTensor<B> {
220        repeat_with_slice_assign::<B, Int>(tensor, dim, times)
221    }
222
223    /// Concatenates the given tensors along the given dimension.
224    ///
225    /// # Arguments
226    ///
227    /// * `tensors` - The tensors.
228    /// * `dim` - The dimension to concatenate along.
229    ///
230    /// # Returns
231    ///
232    /// The concatenated tensor.
233    fn int_cat(tensors: Vec<IntTensor<B>>, dim: usize) -> IntTensor<B> {
234        cat_with_slice_assign::<B, Int>(tensors, dim)
235    }
236
    /// Element-wise equality comparison.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// The boolean tensor with the result of the comparison.
    fn int_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;
248
249    /// Element-wise non-equality comparison.
250    ///
251    /// # Arguments
252    ///
253    /// * `lhs` - The left hand side tensor.
254    /// * `rhs` - The right hand side tensor.
255    ///
256    /// # Returns
257    ///
258    /// The boolean tensor with the result of the comparison.
259    fn int_not_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {
260        let equal_tensor = B::int_equal(lhs, rhs);
261        B::bool_not(equal_tensor)
262    }
263
    /// Element-wise equality comparison with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side scalar.
    ///
    /// # Returns
    ///
    /// The boolean tensor with the result of the comparison.
    fn int_equal_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;
275
276    /// Element-wise non-equality comparison with a scalar.
277    ///
278    /// # Arguments
279    ///
280    /// * `lhs` - The left hand side tensor.
281    /// * `rhs` - The right hand side scalar.
282    ///
283    /// # Returns
284    ///
285    /// The boolean tensor with the result of the comparison.
286    fn int_not_equal_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B> {
287        let equal_tensor = B::int_equal_elem(lhs, rhs);
288        B::bool_not(equal_tensor)
289    }
290
    /// Element-wise greater than comparison.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// The boolean tensor with the result of the comparison.
    fn int_greater(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;

    /// Element-wise greater than comparison with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side scalar.
    ///
    /// # Returns
    ///
    /// The boolean tensor with the result of the comparison.
    fn int_greater_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;

    /// Element-wise greater than or equal comparison.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// The boolean tensor with the result of the comparison.
    fn int_greater_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;

    /// Element-wise greater than or equal comparison with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side scalar.
    ///
    /// # Returns
    ///
    /// The boolean tensor with the result of the comparison.
    fn int_greater_equal_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;

    /// Element-wise less than comparison (`lhs < rhs`).
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// The boolean tensor with the result of the comparison.
    fn int_lower(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;

    /// Element-wise less than comparison with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side scalar.
    ///
    /// # Returns
    ///
    /// The boolean tensor with the result of the comparison.
    fn int_lower_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;

    /// Element-wise less than or equal comparison.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// The boolean tensor with the result of the comparison.
    fn int_lower_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;

    /// Element-wise less than or equal comparison with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side scalar.
    ///
    /// # Returns
    ///
    /// The boolean tensor with the result of the comparison.
    fn int_lower_equal_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;
386
    // ====  NUMERIC ==== //

    /// Element-wise addition.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// The result of the addition.
    fn int_add(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

    /// Element-wise addition with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side scalar.
    ///
    /// # Returns
    ///
    /// The result of the addition.
    fn int_add_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
412
413    /// Element-wise power with a IntTensor.
414    ///
415    /// # Arguments
416    ///
417    /// * `lhs` - The left hand side IntTensor.
418    /// * `rhs` - The right hand side IntTensor.
419    ///
420    /// # Returns
421    ///
422    /// The elements of `lhs` raised to the power of the elements of `rhs`.
423    fn int_powi(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
424        B::float_into_int(B::float_powf(
425            B::int_into_float(lhs),
426            B::int_into_float(rhs),
427        ))
428    }
429
430    /// Element-wise power with a floatTensor.
431    ///
432    /// # Arguments
433    ///
434    /// * `lhs` - The left hand side tensor.
435    /// * `rhs` - The right hand side floatTensor.
436    ///
437    /// # Returns
438    ///
439    /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
440    fn int_powf(lhs: IntTensor<B>, rhs: FloatTensor<B>) -> IntTensor<B> {
441        B::float_into_int(B::float_powf(B::int_into_float(lhs), rhs))
442    }
443
444    /// Element-wise power with a scalar.
445    ///
446    /// # Arguments
447    ///
448    /// * `lhs` - The left hand side tensor.
449    /// * `rhs` - The right hand side scalar.
450    ///
451    /// # Returns
452    ///
453    /// The elements of `lhs` raised to the value of `rhs`.
454    fn int_powi_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B> {
455        B::float_into_int(B::float_powf_scalar(B::int_into_float(lhs), rhs.to_f32()))
456    }
457
458    /// Element-wise power with a floatTensor.
459    ///
460    /// # Arguments
461    ///
462    /// * `lhs` - The left hand side tensor.
463    /// * `rhs` - The right hand side scalar.
464    ///
465    /// # Returns
466    ///
467    /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
468    fn int_powf_scalar(lhs: IntTensor<B>, rhs: f32) -> IntTensor<B> {
469        B::float_into_int(B::float_powf_scalar(B::int_into_float(lhs), rhs))
470    }
471
472    /// Clamps a tensor under a minimum value.
473    ///
474    /// # Arguments
475    ///
476    /// * `tensor` - The tensor to clamp.
477    /// * `min` - The minimum value.
478    ///
479    /// # Returns
480    ///
481    /// The clamped tensor.
482    fn int_clamp_min(tensor: IntTensor<B>, min: IntElem<B>) -> IntTensor<B> {
483        let mask = Self::int_lower_elem(tensor.clone(), min);
484        Self::int_mask_fill(tensor, mask, min)
485    }
486
487    /// Clamps a tensor over a maximum value.
488    ///
489    /// # Arguments
490    ///
491    /// * `tensor` - The tensor to clamp.
492    /// * `max` - The maximum value.
493    ///
494    /// # Returns
495    ///
496    /// The clamped tensor.
497    fn int_clamp_max(tensor: IntTensor<B>, max: IntElem<B>) -> IntTensor<B> {
498        let mask = Self::int_greater_elem(tensor.clone(), max);
499        Self::int_mask_fill(tensor, mask, max)
500    }
501
502    /// Clamps a tensor between a minimum and maximum value.
503    ///
504    /// # Arguments
505    ///
506    /// * `tensor` - The tensor to clamp.
507    /// * `min` - The minimum value.
508    /// * `max` - The maximum value.
509    ///
510    /// # Returns
511    ///
512    /// The clamped tensor.
513    fn int_clamp(tensor: IntTensor<B>, min: IntElem<B>, max: IntElem<B>) -> IntTensor<B> {
514        Self::int_clamp_min(Self::int_clamp_max(tensor, max), min)
515    }
516
    /// Element-wise subtraction.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// The result of the subtraction.
    fn int_sub(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

    /// Element-wise subtraction with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side scalar.
    ///
    /// # Returns
    ///
    /// The result of the subtraction.
    fn int_sub_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;

    /// Element-wise multiplication.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// The result of the multiplication.
    fn int_mul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

    /// Element-wise multiplication with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side scalar.
    ///
    /// # Returns
    ///
    /// The result of the multiplication.
    fn int_mul_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;

    /// Element-wise division.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// The result of the division.
    fn int_div(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

    /// Element-wise division with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side scalar.
    ///
    /// # Returns
    ///
    /// The result of the division.
    fn int_div_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;

    /// Element-wise modulus.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// The result of applying the modulus of `rhs` to `lhs`, element by element.
    fn int_remainder(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

    /// Element-wise modulus with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side scalar.
    ///
    /// # Returns
    ///
    /// The result of applying the modulus of the scalar to the tensor.
    fn int_remainder_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
610
611    /// Element-wise negation.
612    ///
613    /// # Arguments
614    ///
615    /// * `tensor` - The tensor to negate.
616    ///
617    /// # Returns
618    ///
619    /// The negated tensor.
620    fn int_neg(tensor: IntTensor<B>) -> IntTensor<B> {
621        Self::int_mul_scalar(tensor, (-1.0).elem::<IntElem<B>>())
622    }
623
    /// Creates a tensor of zeros.
    ///
    /// # Arguments
    ///
    /// * `shape` - The shape of the tensor.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// The tensor of zeros.
    fn int_zeros(shape: Shape, device: &Device<B>) -> IntTensor<B>;

    /// Creates a tensor of ones.
    ///
    /// # Arguments
    ///
    /// * `shape` - The shape of the tensor.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// The tensor of ones.
    fn int_ones(shape: Shape, device: &Device<B>) -> IntTensor<B>;
647
648    /// Creates a tensor filled with given value.
649    ///
650    /// # Arguments
651    ///
652    /// * `shape` - The shape of the tensor.
653    /// * `fill_value` - The value with which to fill the tensor.
654    /// * `device` - The device to create the tensor on.
655    ///
656    /// # Returns
657    ///
658    /// The tensor filled with given value
659    fn int_full(shape: Shape, fill_value: IntElem<B>, device: &Device<B>) -> IntTensor<B> {
660        Self::int_add_scalar(Self::int_zeros(shape, device), fill_value)
661    }
662
    /// Sums all elements in the tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to sum.
    ///
    /// # Returns
    ///
    /// The sum of all elements in the tensor.
    fn int_sum(tensor: IntTensor<B>) -> IntTensor<B>;

    /// Sums all elements in the tensor along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to sum.
    /// * `dim` - The dimension to sum along.
    ///
    /// # Returns
    ///
    /// The sum of all elements in the tensor along the dimension.
    fn int_sum_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;

    /// Computes the product of all elements in the tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the product of.
    ///
    /// # Returns
    ///
    /// The product of all elements in the tensor.
    fn int_prod(tensor: IntTensor<B>) -> IntTensor<B>;

    /// Computes the product of all elements in the tensor along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the product of.
    /// * `dim` - The dimension to compute the product along.
    ///
    /// # Returns
    ///
    /// The product of all elements in the tensor along the dimension.
    fn int_prod_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
708
709    /// Computes the mean of all elements in the tensor.
710    ///
711    /// # Arguments
712    ///
713    /// * `tensor` - The tensor to compute the mean of.
714    ///
715    /// # Returns
716    ///
717    /// The mean of all elements in the tensor.
718    fn int_mean(tensor: IntTensor<B>) -> IntTensor<B> {
719        let num_elems = tensor.shape().num_elements();
720        B::int_div_scalar(B::int_sum(tensor), (num_elems as i64).elem())
721    }
722
    /// Computes the mean of all elements in the tensor along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the mean of.
    /// * `dim` - The dimension to compute the mean along.
    ///
    /// # Returns
    ///
    /// The mean of all elements in the tensor along the dimension.
    fn int_mean_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;

    /// Gets the indices of the maximum elements along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the maximum indices of.
    /// * `dim` - The dimension to get the maximum indices along.
    ///
    /// # Returns
    ///
    /// The indices of the maximum elements along the dimension.
    fn int_argmax(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;

    /// Gets the indices of the minimum elements along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the minimum indices of.
    /// * `dim` - The dimension to get the minimum indices along.
    ///
    /// # Returns
    ///
    /// The indices of the minimum elements along the dimension.
    fn int_argmin(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
757
758    /// Gets the maximum element in the tensor.
759    ///
760    /// # Arguments
761    ///
762    /// * `tensor` - The tensor to get the maximum element of.
763    ///
764    /// # Returns
765    ///
766    /// The maximum element in the tensor.
767    fn int_max(tensor: IntTensor<B>) -> IntTensor<B> {
768        let shape = tensor.shape();
769        let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()]));
770
771        B::int_max_dim(tensor, 0)
772    }
773
774    /// Gets the maximum element in the tensor along a dimension.
775    ///
776    /// # Arguments
777    ///
778    /// * `tensor` - The tensor to get the maximum element of.
779    /// * `dim` - The dimension to get the maximum element along.
780    ///
781    /// # Returns
782    ///
783    /// The maximum element in the tensor along the dimension.
784    fn int_max_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
785        let index = B::int_argmax(tensor.clone(), dim);
786        let ndim = tensor.shape().num_dims();
787
788        B::int_gather(ndim - 1, tensor, index)
789    }
790
791    /// Gets the maximum elements and corresponding indices along a dimension.
792    ///
793    /// # Arguments
794    ///
795    /// * `tensor` - The tensor to get the maximum elements and indices of.
796    /// * `dim` - The dimension to get the maximum elements and indices along.
797    ///
798    /// # Returns
799    ///
800    /// The maximum elements and corresponding indices along the dimension.
801    fn int_max_dim_with_indices(tensor: IntTensor<B>, dim: usize) -> (IntTensor<B>, IntTensor<B>) {
802        let index = B::int_argmax(tensor.clone(), dim);
803        let values = B::int_gather(dim, tensor, index.clone());
804
805        (values, index)
806    }
807
808    /// Gets the minimum element in the tensor.
809    ///
810    /// # Arguments
811    ///
812    /// * `tensor` - The tensor to get the minimum element of.
813    ///
814    /// # Returns
815    ///
816    /// The minimum element in the tensor.
817    fn int_min(tensor: IntTensor<B>) -> IntTensor<B> {
818        let shape = tensor.shape();
819        let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()]));
820
821        B::int_min_dim(tensor, 0)
822    }
823
824    /// Gets the minimum elements in the tensor along a dimension.
825    ///
826    /// # Arguments
827    ///
828    /// * `tensor` - The tensor to get the minimum element of.
829    /// * `dim` - The dimension to get the minimum element along.
830    ///
831    /// # Returns
832    ///
833    /// The minimum element in the tensor along the dimension.
834    fn int_min_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
835        let index = B::int_argmin(tensor.clone(), dim);
836        let ndim = tensor.shape().num_dims();
837
838        B::int_gather(ndim - 1, tensor, index)
839    }
840
841    /// Gets the minimum elements and corresponding indices along a dimension.
842    ///
843    /// # Arguments
844    ///
845    /// * `tensor` - The tensor to get the minimum elements and indices of.
846    /// * `dim` - The dimension to get the minimum elements and indices along.
847    ///
848    /// # Returns
849    ///
850    /// The minimum elements and corresponding indices along the dimension.
851    fn int_min_dim_with_indices(tensor: IntTensor<B>, dim: usize) -> (IntTensor<B>, IntTensor<B>) {
852        let indices = B::int_argmin(tensor.clone(), dim);
853        let ndim = tensor.shape().num_dims();
854        let values = B::int_gather(ndim - 1, tensor, indices.clone());
855
856        (values, indices)
857    }
858
    /// Returns a new tensor with absolute values.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take absolute value of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with absolute values.
    fn int_abs(tensor: IntTensor<B>) -> IntTensor<B>;
869
870    /// Transposes an int tensor.
871    ///
872    /// # Arguments
873    ///
874    /// * `tensor` - The tensor to transpose.
875    ///
876    /// # Returns
877    ///
878    /// The transposed tensor.
879    fn int_transpose(tensor: IntTensor<B>) -> IntTensor<B> {
880        let ndims = tensor.shape().num_dims();
881        Self::int_swap_dims(tensor, ndims - 2, ndims - 1)
882    }
883
    /// Swaps two dimensions of an int tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to swap the dimensions of.
    /// * `dim1` - The first dimension to swap.
    /// * `dim2` - The second dimension to swap.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions swapped.
    fn int_swap_dims(tensor: IntTensor<B>, dim1: usize, dim2: usize) -> IntTensor<B>;

    /// Permutes the dimensions of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to permute the dimensions of.
    /// * `axes` - The new order of the dimensions.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions permuted.
    fn int_permute(tensor: IntTensor<B>, axes: &[usize]) -> IntTensor<B>;

    /// Reverse the order of elements in a tensor along the given axes.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to reverse.
    /// * `axes` - The axes to reverse.
    ///
    /// # Returns
    ///
    /// The tensor with the elements reversed.
    fn int_flip(tensor: IntTensor<B>, axes: &[usize]) -> IntTensor<B>;
917
918    /// Returns a new tensor with the given dimension narrowed to the given range.
919    ///
920    /// # Arguments
921    ///
922    /// * `dim` - The dimension along which the tensor will be narrowed.
923    /// * `start` - The starting point of the given range.
924    /// * `length` - The ending point of the given range.
925    /// # Panics
926    ///
927    /// - If the dimension is greater than the number of dimensions of the tensor.
928    /// - If the given range exceeds the number of elements on the given dimension.
929    ///
930    /// # Returns
931    ///
932    /// A new tensor with the given dimension narrowed to the given range.
933    fn int_narrow(tensor: IntTensor<B>, dim: usize, start: usize, length: usize) -> IntTensor<B> {
934        narrow::<B, Int>(tensor, dim, start, length)
935    }
936
937    /// Split the tensor along the given dimension into chunks.
938    ///
939    /// # Arguments
940    ///
941    /// * `tensor` - The tensor.
942    /// * `chunks` - The number of chunks to be produced.
943    /// * `times` - The dimension along which the tensor will be split.
944    ///
945    /// # Returns
946    ///
947    /// A vector of tensors
948    fn int_chunk(tensor: IntTensor<B>, chunks: usize, dim: usize) -> Vec<IntTensor<B>> {
949        chunk::<B, Int>(tensor, chunks, dim)
950    }
951
952    /// Split the tensor along the given dimension into chunks of `split_size`.
953    ///
954    /// # Arguments
955    ///
956    /// * `tensor` - The tensor.
957    /// * `split_size` - The size of a single chunk.
958    /// * `times` - The dimension along which the tensor will be split.
959    ///
960    /// # Returns
961    ///
962    /// A vector of tensors.
963    fn int_split(tensor: IntTensor<B>, split_size: usize, dim: usize) -> Vec<IntTensor<B>> {
964        split::<B, Int>(tensor, split_size, dim)
965    }
966
967    /// Split the tensor along the given dimension into chunks with sizes in
968    /// `dim` according to `split_sizes`.
969    ///
970    /// # Arguments
971    ///
972    /// * `tensor` - The tensor.
973    /// * `split_sizes` - Vector of sizes for each chunk.
974    /// * `times` - The dimension along which the tensor will be split.
975    ///
976    /// # Returns
977    ///
978    /// A vector of tensors.
979    fn int_split_with_sizes(
980        tensor: IntTensor<B>,
981        split_sizes: Vec<usize>,
982        dim: usize,
983    ) -> Vec<IntTensor<B>> {
984        split_with_sizes::<B, Int>(tensor, split_sizes, dim)
985    }
986
    /// Creates a new int tensor with random values.
    ///
    /// # Arguments
    ///
    /// * `shape` - The shape of the tensor.
    /// * `distribution` - The distribution to sample from.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// The tensor with the given shape and random values.
    fn int_random(shape: Shape, distribution: Distribution, device: &Device<B>) -> IntTensor<B>;
998
999    /// Creates a new tensor with values from the given range with the given step size.
1000    ///
1001    /// # Arguments
1002    ///
1003    /// * `range` - The range of values.
1004    /// * `step` - The step size.
1005    /// * `device` - The device to create the tensor on.
1006    ///
1007    /// # Returns
1008    ///
1009    /// The tensor with the given values.
1010    fn int_arange_step(range: Range<i64>, step: usize, device: &Device<B>) -> IntTensor<B> {
1011        let value = range
1012            .step_by(step)
1013            .map(|i| i.elem())
1014            .collect::<Vec<IntElem<B>>>();
1015        let shape = Shape::new([value.len()]);
1016        let data = TensorData::new(value, shape);
1017        B::int_from_data(data, device)
1018    }
1019
1020    /// Creates a new tensor with values from the given range.
1021    ///
1022    /// # Arguments
1023    ///
1024    /// * `range` - The range of values.
1025    /// * `device` - The device to create the tensor on.
1026    ///
1027    /// # Returns
1028    ///
1029    /// The tensor with the given values.
1030    ///
1031    /// # Remarks
1032    ///
1033    /// Uses `arange_step` with a step size of 1 under the hood.
1034    fn int_arange(range: Range<i64>, device: &Device<B>) -> IntTensor<B> {
1035        Self::int_arange_step(range, 1, device)
1036    }
1037
1038    /// Tests if any element in the int `tensor` evaluates to True.
1039    ///
1040    /// # Arguments
1041    ///
1042    /// * `tensor` - The tensor to test.
1043    ///
1044    /// # Returns
1045    ///
1046    /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
1047    fn int_any(tensor: IntTensor<B>) -> BoolTensor<B> {
1048        let bool_tensor = B::int_equal_elem(tensor, 0.elem());
1049        let bool_tensor = B::bool_not(bool_tensor);
1050        let sum = B::int_sum(B::bool_into_int(bool_tensor));
1051        B::int_greater_elem(sum, 0.elem())
1052    }
1053
1054    /// Tests if any element in the int `tensor` evaluates to True along a given dimension `dim`.
1055    ///
1056    /// # Arguments
1057    ///
1058    /// * `tensor` - The tensor to test.
1059    /// * `dim` - The axis along which to test.
1060    ///
1061    /// # Returns
1062    ///
1063    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1064    /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input
1065    /// evaluates to True, False otherwise.
1066    fn int_any_dim(tensor: IntTensor<B>, dim: usize) -> BoolTensor<B> {
1067        let bool_tensor = B::int_equal_elem(tensor, 0.elem());
1068        let bool_tensor = B::bool_not(bool_tensor);
1069        let sum = B::int_sum_dim(B::bool_into_int(bool_tensor), dim);
1070        B::int_greater_elem(sum, 0.elem())
1071    }
1072
1073    /// Tests if all elements in the int `tensor` evaluate to True.
1074    ///
1075    /// # Arguments
1076    ///
1077    /// * `tensor` - The tensor to test.
1078    ///
1079    /// # Returns
1080    ///
1081    /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1082    /// evaluate to True, False otherwise.
1083    fn int_all(tensor: IntTensor<B>) -> BoolTensor<B> {
1084        let num_elems = tensor.shape().num_elements();
1085        let bool_tensor = B::int_equal_elem(tensor, 0.elem());
1086        let bool_tensor = B::bool_not(bool_tensor);
1087        let sum = B::int_sum(B::bool_into_int(bool_tensor));
1088        B::int_equal_elem(sum, (num_elems as i32).elem())
1089    }
1090
1091    /// Tests if all elements in the int `tensor` evaluate to True along a given dimension `dim`.
1092    ///
1093    /// # Arguments
1094    ///
1095    /// * `tensor` - The tensor to test.
1096    /// * `dim` - The axis along which to test.
1097    ///
1098    /// # Returns
1099    ///
1100    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1101    /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1102    /// evaluates to True, False otherwise.
1103    fn int_all_dim(tensor: IntTensor<B>, dim: usize) -> BoolTensor<B> {
1104        let num_elems = tensor.shape().dims[dim];
1105        let bool_tensor = B::int_equal_elem(tensor, 0.elem());
1106        let bool_tensor = B::bool_not(bool_tensor);
1107        let sum = B::int_sum_dim(B::bool_into_int(bool_tensor), dim);
1108        B::int_equal_elem(sum, (num_elems as i32).elem())
1109    }
1110
1111    /// Returns the signs of the int `tensor`.
1112    ///
1113    /// # Arguments
1114    ///
1115    /// * `tensor` - The tensor to extract the signs from.
1116    ///
1117    /// # Returns
1118    ///
1119    /// A tensor with the same shape as `tensor` containing the signs of the elements of `tensor`.
1120    fn int_sign(tensor: IntTensor<B>) -> IntTensor<B> {
1121        let zeros = B::int_zeros(tensor.shape(), &B::int_device(&tensor));
1122        let less_than_zero = B::int_lower_elem(tensor.clone(), 0.0f32.elem());
1123        let greater_than_zero = B::int_greater_elem(tensor, 0.0f32.elem());
1124
1125        let mut result = B::int_mask_fill(zeros, less_than_zero, (-1.0f32).elem());
1126        result = B::int_mask_fill(result, greater_than_zero, 1.0f32.elem());
1127        result
1128    }
1129
    /// Broadcasts the int `tensor` to the given `shape`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to broadcast.
    /// * `shape` - The target shape.
    ///
    /// # Returns
    ///
    /// The tensor broadcast to `shape`.
    fn int_expand(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B>;
1132
1133    /// Sort the elements of the input `tensor` by value along a given dimension.
1134    ///
1135    /// This sort is unstable (i.e., may reorder equal elements).
1136    ///
1137    /// # Arguments
1138    ///
1139    /// * `tensor` - The input tensor.
1140    /// * `dim` - The axis along which to sort.
1141    /// * `descending` - The sorting order.
1142    ///
1143    /// # Returns
1144    ///
1145    /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
1146    fn int_sort(tensor: IntTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
1147        sort::<B, Int>(tensor, dim, descending)
1148    }
1149
1150    /// Sort the elements of the input `tensor` by value along a given dimension.
1151    ///
1152    /// This sort is unstable (i.e., may reorder equal elements).
1153    ///
1154    /// # Arguments
1155    ///
1156    /// * `tensor` - The input tensor.
1157    /// * `dim` - The axis along which to sort.
1158    ///
1159    /// # Returns
1160    ///
1161    /// A tensor with the same shape as the input tensor and corresponding indices, where
1162    /// the elements are sorted by value and the indices map back to the original input tensor.
1163    fn int_sort_with_indices(
1164        tensor: IntTensor<B>,
1165        dim: usize,
1166        descending: bool,
1167    ) -> (IntTensor<B>, IntTensor<B>) {
1168        sort_with_indices::<B, Int>(tensor, dim, descending)
1169    }
1170
1171    /// Returns the indices that sort the elements of the input `tensor` by value
1172    /// along a given dimension.
1173    ///
1174    /// This sort is unstable (i.e., may reorder equal elements).
1175    ///
1176    /// # Arguments
1177    ///
1178    /// * `tensor` - The input tensor.
1179    /// * `dim` - The axis along which to sort.
1180    /// * `descending` - The sorting order.
1181    ///
1182    /// # Returns
1183    ///
1184    /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.
1185    fn int_argsort(tensor: IntTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
1186        argsort::<B, Int>(tensor, dim, descending)
1187    }
1188}