burn_tensor/tensor/ops/
int_tensor.rs

1use super::cat::cat_with_slice_assign;
2use super::repeat_dim::repeat_with_slice_assign;
3use super::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
4use crate::cast::ToElement;
5use crate::tensor::api::{chunk, narrow, split, split_with_sizes};
6use crate::{Distribution, ElementConversion, Int, TensorData, backend::Backend, tensor::Shape};
7use alloc::vec::Vec;
8use core::future::Future;
9use core::ops::Range;
10
11use crate::{TensorMetadata, argsort, sort, sort_with_indices};
12
13/// Int Tensor API for basic and numeric operations, see [tensor](crate::Tensor)
14/// for documentation on each function.
15pub trait IntTensorOps<B: Backend> {
16    /// Creates a new int tensor.
17    ///
18    /// # Arguments
19    ///
20    /// * `shape` - The shape of the tensor.
21    /// * `device` - The device to create the tensor on.
22    ///
23    /// # Returns
24    ///
25    /// The integer tensor with the given shape.
26    fn int_empty(shape: Shape, device: &Device<B>) -> IntTensor<B>;
27
28    /// Converts the tensor to a data structure.
29    ///
30    /// # Arguments
31    ///
32    /// * `tensor` - The tensor.
33    ///
34    /// # Returns
35    ///
36    /// The data structure with the tensor's data.
37    fn int_into_data(tensor: IntTensor<B>) -> impl Future<Output = TensorData> + 'static + Send;
38
39    /// Creates a tensor from the data structure.
40    ///
41    /// # Arguments
42    ///
43    /// * `data` - The data structure.
44    /// * `device` - The device to create the tensor on.
45    ///
46    /// # Returns
47    ///
48    /// The tensor with the data.
49    fn int_from_data(data: TensorData, device: &Device<B>) -> IntTensor<B>;
50
51    /// Gets the device of the tensor.
52    ///
53    /// # Arguments
54    ///
55    /// * `tensor` - The tensor.
56    ///
57    /// # Returns
58    ///
59    /// The device of the tensor.
60    fn int_device(tensor: &IntTensor<B>) -> Device<B>;
61
62    /// Moves the tensor to the given device.
63    fn int_to_device(tensor: IntTensor<B>, device: &Device<B>) -> IntTensor<B>;
64
65    /// Reshapes the tensor.
66    ///
67    /// # Arguments
68    ///
69    /// * `tensor` - The tensor.
70    /// * `shape` - The new shape.
71    ///
72    /// # Returns
73    ///
74    /// The tensor with the new shape.
75    fn int_reshape(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B>;
76
77    /// Gets the element at the given indices.
78    ///
79    /// # Arguments
80    ///
81    /// * `tensor` - The tensor.
82    /// * `indices` - The indices.
83    ///
84    /// # Returns
85    ///
86    /// The elements at the given indices.
87    fn int_slice(tensor: IntTensor<B>, indices: &[Range<usize>]) -> IntTensor<B>;
88
89    /// Sets the element at the given indices.
90    ///
91    /// # Arguments
92    ///
93    /// * `tensor` - The tensor.
94    /// * `indices` - The indices.
95    ///
96    /// # Returns
97    ///
98    /// The tensor with the element at the given indices set.
99    fn int_slice_assign(
100        tensor: IntTensor<B>,
101        indices: &[Range<usize>],
102        value: IntTensor<B>,
103    ) -> IntTensor<B>;
104
105    /// Converts int tensor to float tensor.
106    ///
107    /// # Arguments
108    ///
109    /// * `tensor` - The tensor.
110    ///
111    /// # Returns
112    ///
113    /// The int tensor with the same data as the float tensor.
114    fn int_into_float(tensor: IntTensor<B>) -> FloatTensor<B>;
115
116    /// Fills the tensor with values from the source tensor if the mask is true at the given
117    /// indices.
118    ///
119    /// # Arguments
120    ///
121    /// * `tensor` - The tensor.
122    /// * `mask` - The mask.
123    /// * `source` - The source tensor.
124    ///
125    /// # Returns
126    ///
127    /// The tensor with the values filled.
128    fn int_mask_where(
129        tensor: IntTensor<B>,
130        mask: BoolTensor<B>,
131        source: IntTensor<B>,
132    ) -> IntTensor<B>;
133
134    /// Fills the tensor with the given value if the mask is true at the given indices.
135    ///
136    /// # Arguments
137    ///
138    /// * `tensor` - The tensor.
139    /// * `mask` - The mask.
140    /// * `value` - The value.
141    ///
142    /// # Returns
143    ///
144    /// The tensor with the values filled.
145    fn int_mask_fill(tensor: IntTensor<B>, mask: BoolTensor<B>, value: IntElem<B>) -> IntTensor<B>;
146
147    /// Gather elements from the tensor at the given indices.
148    ///
149    /// # Arguments
150    ///
151    /// * `dim` - The dimension to gather from.
152    /// * `tensor` - The tensor.
153    /// * `indices` - The indices.
154    fn int_gather(dim: usize, tensor: IntTensor<B>, indices: IntTensor<B>) -> IntTensor<B>;
155
156    /// Scatter a given value to the tensor at the given indices.
157    ///
158    /// # Arguments
159    ///
160    /// * `dim` - The dimension to scatter to.
161    /// * `tensor` - The tensor.
162    /// * `indices` - The indices.
163    /// * `value` - The value.
164    ///
165    /// # Returns
166    ///
167    /// The tensor with the values scattered.
168    fn int_scatter(
169        dim: usize,
170        tensor: IntTensor<B>,
171        indices: IntTensor<B>,
172        value: IntTensor<B>,
173    ) -> IntTensor<B>;
174
175    /// Select tensor elements along the given dimension corresponding to the given indices.
176    ///
177    /// # Arguments
178    ///
179    /// * `tensor` - The tensor.
180    /// * `dim` - The dimension to select from.
181    /// * `indices` - The indices.
182    ///
183    /// # Returns
184    ///
185    /// The tensor with the selected elements.
186    fn int_select(tensor: IntTensor<B>, dim: usize, indices: IntTensor<B>) -> IntTensor<B>;
187
188    /// Assign the selected elements along the given dimension corresponding to the given indices
189    /// to the given value.
190    ///
191    /// # Arguments
192    ///
193    /// * `tensor` - The tensor.
194    /// * `dim` - The dimension to select from.
195    /// * `indices` - The indices.
196    /// * `value` - The value.
197    ///
198    /// # Returns
199    ///
200    /// The tensor with the selected elements assigned to the given value.
201    fn int_select_assign(
202        tensor: IntTensor<B>,
203        dim: usize,
204        indices: IntTensor<B>,
205        value: IntTensor<B>,
206    ) -> IntTensor<B>;
207
208    /// Repeats the tensor along the given dimension the given number of times.
209    ///
210    /// # Arguments
211    ///
212    /// * `tensor` - The tensor.
213    /// * `dim` - The dimension to repeat.
214    /// * `times` - The number of times to repeat.
215    ///
216    /// # Returns
217    ///
218    /// The tensor with the given dimension repeated the given number of times.
219    fn int_repeat_dim(tensor: IntTensor<B>, dim: usize, times: usize) -> IntTensor<B> {
220        repeat_with_slice_assign::<B, Int>(tensor, dim, times)
221    }
222
223    /// Concatenates the given tensors along the given dimension.
224    ///
225    /// # Arguments
226    ///
227    /// * `tensors` - The tensors.
228    /// * `dim` - The dimension to concatenate along.
229    ///
230    /// # Returns
231    ///
232    /// The concatenated tensor.
233    fn int_cat(tensors: Vec<IntTensor<B>>, dim: usize) -> IntTensor<B> {
234        cat_with_slice_assign::<B, Int>(tensors, dim)
235    }
236
237    /// Element-wise equality comparison.
238    ///
239    /// # Arguments
240    ///
241    /// * `lhs` - The left hand side tensor.
242    /// * `rhs` - The right hand side tensor.
243    ///
244    /// # Returns
245    ///
246    /// The boolean tensor with the result of the comparison.
247    fn int_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;
248
249    /// Element-wise non-equality comparison.
250    ///
251    /// # Arguments
252    ///
253    /// * `lhs` - The left hand side tensor.
254    /// * `rhs` - The right hand side tensor.
255    ///
256    /// # Returns
257    ///
258    /// The boolean tensor with the result of the comparison.
259    fn int_not_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {
260        let equal_tensor = B::int_equal(lhs, rhs);
261        B::bool_not(equal_tensor)
262    }
263
264    /// Element-wise equality comparison with a scalar.
265    ///
266    /// # Arguments
267    ///
268    /// * `lhs` - The left hand side tensor.
269    /// * `rhs` - The right hand side scalar.
270    ///
271    /// # Returns
272    ///
273    /// The boolean tensor with the result of the comparison.
274    fn int_equal_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;
275
276    /// Element-wise non-equality comparison with a scalar.
277    ///
278    /// # Arguments
279    ///
280    /// * `lhs` - The left hand side tensor.
281    /// * `rhs` - The right hand side scalar.
282    ///
283    /// # Returns
284    ///
285    /// The boolean tensor with the result of the comparison.
286    fn int_not_equal_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B> {
287        let equal_tensor = B::int_equal_elem(lhs, rhs);
288        B::bool_not(equal_tensor)
289    }
290
291    /// Element-wise greater than comparison.
292    ///
293    /// # Arguments
294    ///
295    /// * `lhs` - The left hand side tensor.
296    /// * `rhs` - The right hand side tensor.
297    ///
298    /// # Returns
299    ///
300    /// The boolean tensor with the result of the comparison.
301    fn int_greater(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;
302
303    /// Element-wise greater than comparison with a scalar.
304    ///
305    /// # Arguments
306    ///
307    /// * `lhs` - The left hand side tensor.
308    /// * `rhs` - The right hand side scalar.
309    ///
310    /// # Returns
311    ///
312    /// The boolean tensor with the result of the comparison.
313    fn int_greater_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;
314
315    /// Element-wise greater than or equal comparison.
316    ///
317    /// # Arguments
318    ///
319    /// * `lhs` - The left hand side tensor.
320    /// * `rhs` - The right hand side tensor.
321    ///
322    /// # Returns
323    ///
324    /// The boolean tensor with the result of the comparison.
325    fn int_greater_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;
326
327    /// Element-wise greater than or equal comparison with a scalar.
328    ///
329    /// # Arguments
330    ///
331    /// * `lhs` - The left hand side tensor.
332    /// * `rhs` - The right hand side scalar.
333    ///
334    /// # Returns
335    ///
336    /// The boolean tensor with the result of the comparison.
337    fn int_greater_equal_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;
338
339    /// Element-wise less than comparison.
340    ///
341    /// # Arguments
342    ///
343    /// * `lhs` - The left hand side tensor.
344    /// * `rhs` - The right hand side tensor.
345    ///
346    /// # Returns
347    ///
348    /// The boolean tensor with the result of the comparison.
349    fn int_lower(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;
350
351    /// Element-wise less than comparison with a scalar.
352    ///
353    /// # Arguments
354    ///
355    /// * `lhs` - The left hand side tensor.
356    /// * `rhs` - The right hand side scalar.
357    ///
358    /// # Returns
359    ///
360    /// The boolean tensor with the result of the comparison.
361    fn int_lower_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;
362
363    /// Element-wise less than or equal comparison.
364    ///
365    /// # Arguments
366    ///
367    /// * `lhs` - The left hand side tensor.
368    /// * `rhs` - The right hand side tensor.
369    ///
370    /// # Returns
371    ///
372    /// The boolean tensor with the result of the comparison.
373    fn int_lower_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;
374
375    /// Element-wise less than or equal comparison with a scalar.
376    ///
377    /// # Arguments
378    ///
379    /// * `lhs` - The left hand side tensor.
380    /// * `rhs` - The right hand side scalar.
381    ///
382    /// # Returns
383    ///
384    /// The boolean tensor with the result of the comparison.
385    fn int_lower_equal_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;
386
387    // ====  NUMERIC ==== //
388
389    /// Element-wise addition.
390    ///
391    /// # Arguments
392    ///
393    /// * `lhs` - The left hand side tensor.
394    /// * `rhs` - The right hand side tensor.
395    ///
396    /// # Returns
397    ///
398    /// The result of the addition.
399    fn int_add(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
400
401    /// Element-wise addition with a scalar.
402    ///
403    /// # Arguments
404    ///
405    /// * `lhs` - The left hand side tensor.
406    /// * `rhs` - The right hand side scalar.
407    ///
408    /// # Returns
409    ///
410    /// The result of the addition.
411    fn int_add_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
412
413    /// Element-wise power with a IntTensor.
414    ///
415    /// # Arguments
416    ///
417    /// * `lhs` - The left hand side IntTensor.
418    /// * `rhs` - The right hand side IntTensor.
419    ///
420    /// # Returns
421    ///
422    /// The elements of `lhs` raised to the power of the elements of `rhs`.
423    fn int_powi(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
424        B::float_into_int(B::float_powf(
425            B::int_into_float(lhs),
426            B::int_into_float(rhs),
427        ))
428    }
429
430    /// Element-wise power with a floatTensor.
431    ///
432    /// # Arguments
433    ///
434    /// * `lhs` - The left hand side tensor.
435    /// * `rhs` - The right hand side floatTensor.
436    ///
437    /// # Returns
438    ///
439    /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
440    fn int_powf(lhs: IntTensor<B>, rhs: FloatTensor<B>) -> IntTensor<B> {
441        B::float_into_int(B::float_powf(B::int_into_float(lhs), rhs))
442    }
443
444    /// Element-wise power with a scalar.
445    ///
446    /// # Arguments
447    ///
448    /// * `lhs` - The left hand side tensor.
449    /// * `rhs` - The right hand side scalar.
450    ///
451    /// # Returns
452    ///
453    /// The elements of `lhs` raised to the value of `rhs`.
454    fn int_powi_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B> {
455        B::float_into_int(B::float_powf_scalar(B::int_into_float(lhs), rhs.to_f32()))
456    }
457
458    /// Element-wise power with a floatTensor.
459    ///
460    /// # Arguments
461    ///
462    /// * `lhs` - The left hand side tensor.
463    /// * `rhs` - The right hand side scalar.
464    ///
465    /// # Returns
466    ///
467    /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
468    fn int_powf_scalar(lhs: IntTensor<B>, rhs: f32) -> IntTensor<B> {
469        B::float_into_int(B::float_powf_scalar(B::int_into_float(lhs), rhs))
470    }
471
472    /// Clamps a tensor under a minimum value.
473    ///
474    /// # Arguments
475    ///
476    /// * `tensor` - The tensor to clamp.
477    /// * `min` - The minimum value.
478    ///
479    /// # Returns
480    ///
481    /// The clamped tensor.
482    fn int_clamp_min(tensor: IntTensor<B>, min: IntElem<B>) -> IntTensor<B> {
483        let mask = Self::int_lower_elem(tensor.clone(), min);
484        Self::int_mask_fill(tensor, mask, min)
485    }
486
487    /// Clamps a tensor over a maximum value.
488    ///
489    /// # Arguments
490    ///
491    /// * `tensor` - The tensor to clamp.
492    /// * `max` - The maximum value.
493    ///
494    /// # Returns
495    ///
496    /// The clamped tensor.
497    fn int_clamp_max(tensor: IntTensor<B>, max: IntElem<B>) -> IntTensor<B> {
498        let mask = Self::int_greater_elem(tensor.clone(), max);
499        Self::int_mask_fill(tensor, mask, max)
500    }
501
502    /// Clamps a tensor between a minimum and maximum value.
503    ///
504    /// # Arguments
505    ///
506    /// * `tensor` - The tensor to clamp.
507    /// * `min` - The minimum value.
508    /// * `max` - The maximum value.
509    ///
510    /// # Returns
511    ///
512    /// The clamped tensor.
513    fn int_clamp(tensor: IntTensor<B>, min: IntElem<B>, max: IntElem<B>) -> IntTensor<B> {
514        Self::int_clamp_min(Self::int_clamp_max(tensor, max), min)
515    }
516
517    /// Element-wise subtraction.
518    ///
519    /// # Arguments
520    ///
521    /// * `lhs` - The left hand side tensor.
522    /// * `rhs` - The right hand side tensor.
523    ///
524    /// # Returns
525    ///
526    /// The result of the subtraction.
527    fn int_sub(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
528
529    /// Element-wise subtraction with a scalar.
530    ///
531    /// # Arguments
532    ///
533    /// * `lhs` - The left hand side tensor.
534    /// * `rhs` - The right hand side scalar.
535    ///
536    /// # Returns
537    ///
538    /// The result of the subtraction.
539    fn int_sub_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
540
541    /// Element-wise multiplication.
542    ///
543    /// # Arguments
544    ///
545    /// * `lhs` - The left hand side tensor.
546    /// * `rhs` - The right hand side tensor.
547    ///
548    /// # Returns
549    ///
550    /// The result of the multiplication.
551    fn int_mul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
552
553    /// Element-wise multiplication with a scalar.
554    ///
555    /// # Arguments
556    ///
557    /// * `lhs` - The left hand side tensor.
558    /// * `rhs` - The right hand side scalar.
559    ///
560    /// # Returns
561    ///
562    /// The result of the multiplication.
563    fn int_mul_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
564
565    /// Element-wise division.
566    ///
567    /// # Arguments
568    ///
569    /// * `lhs` - The left hand side tensor.
570    /// * `rhs` - The right hand side tensor.
571    ///
572    /// # Returns
573    ///
574    /// The result of the division.
575    fn int_div(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
576
577    /// Element-wise division with a scalar.
578    ///
579    /// # Arguments
580    ///
581    /// * `lhs` - The left hand side tensor.
582    /// * `rhs` - The right hand side scalar.
583    ///
584    /// # Returns
585    ///
586    /// The result of the division.
587    fn int_div_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
588
589    /// Element-wise modulus.
590    ///
591    /// # Arguments
592    /// * `lhs` - The left hand side tensor.
593    /// * `rhs` - The right hand side scalar.
594    ///
595    /// # Returns
596    ///
597    /// The result of applying the modulus of the scalar to the tensor.
598    fn int_remainder(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
599
600    /// Element-wise modulus with a scalar.
601    ///
602    /// # Arguments
603    /// * `lhs` - The left hand side tensor.
604    /// * `rhs` - The right hand side scalar.
605    ///
606    /// # Returns
607    ///
608    /// The result of applying the modulus of the scalar to the tensor.
609    fn int_remainder_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
610
611    /// Element-wise negation.
612    ///
613    /// # Arguments
614    ///
615    /// * `tensor` - The tensor to negate.
616    ///
617    /// # Returns
618    ///
619    /// The negated tensor.
620    fn int_neg(tensor: IntTensor<B>) -> IntTensor<B> {
621        Self::int_mul_scalar(tensor, (-1.0).elem::<IntElem<B>>())
622    }
623
624    /// Creates a tensor of zeros.
625    ///
626    /// # Arguments
627    ///
628    /// * `shape` - The shape of the tensor.
629    /// * `device` - The device to create the tensor on.
630    ///
631    /// # Returns
632    ///
633    /// The tensor of zeros.
634    fn int_zeros(shape: Shape, device: &Device<B>) -> IntTensor<B>;
635
636    /// Creates a tensor of ones.
637    ///
638    /// # Arguments
639    ///
640    /// * `shape` - The shape of the tensor.
641    /// * `device` - The device to create the tensor on.
642    ///
643    /// # Returns
644    ///
645    /// The tensor of ones.
646    fn int_ones(shape: Shape, device: &Device<B>) -> IntTensor<B>;
647
648    /// Creates a tensor filled with given value.
649    ///
650    /// # Arguments
651    ///
652    /// * `shape` - The shape of the tensor.
653    /// * `fill_value` - The value with which to fill the tensor.
654    /// * `device` - The device to create the tensor on.
655    ///
656    /// # Returns
657    ///
658    /// The tensor filled with given value
659    fn int_full(shape: Shape, fill_value: IntElem<B>, device: &Device<B>) -> IntTensor<B> {
660        Self::int_add_scalar(Self::int_zeros(shape, device), fill_value)
661    }
662
663    /// Sums all elements in the tensor.
664    ///
665    /// # Arguments
666    ///
667    /// * `tensor` - The tensor to sum.
668    ///
669    /// # Returns
670    ///
671    /// The sum of all elements in the tensor.
672    fn int_sum(tensor: IntTensor<B>) -> IntTensor<B>;
673
674    /// Sums all elements in the tensor along a dimension.
675    ///
676    /// # Arguments
677    ///
678    /// * `tensor` - The tensor to sum.
679    /// * `dim` - The dimension to sum along.
680    ///
681    /// # Returns
682    ///
683    /// The sum of all elements in the tensor along the dimension.
684    fn int_sum_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
685
686    /// Computes the product of all elements in the tensor.
687    ///
688    /// # Arguments
689    ///
690    /// * `tensor` - The tensor to compute the product of.
691    ///
692    /// # Returns
693    ///
694    /// The product of all elements in the tensor.
695    fn int_prod(tensor: IntTensor<B>) -> IntTensor<B>;
696
697    /// Computes the product of all elements in the tensor along a dimension.
698    ///
699    /// # Arguments
700    ///
701    /// * `tensor` - The tensor to compute the product of.
702    /// * `dim` - The dimension to compute the product along.
703    ///
704    /// # Returns
705    ///
706    /// The product of all elements in the tensor along the dimension.
707    fn int_prod_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
708
709    /// Computes the mean of all elements in the tensor.
710    ///
711    /// # Arguments
712    ///
713    /// * `tensor` - The tensor to compute the mean of.
714    ///
715    /// # Returns
716    ///
717    /// The mean of all elements in the tensor.
718    fn int_mean(tensor: IntTensor<B>) -> IntTensor<B> {
719        let num_elems = tensor.shape().num_elements();
720        B::int_div_scalar(B::int_sum(tensor), (num_elems as i64).elem())
721    }
722
723    /// Computes the mean of all elements in the tensor along a dimension.
724    ///
725    /// # Arguments
726    ///
727    /// * `tensor` - The tensor to compute the mean of.
728    ///
729    /// # Returns
730    ///
731    /// The mean of all elements in the tensor along the dimension.
732    fn int_mean_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
733
734    /// Gets the indices of the maximum elements along a dimension.
735    ///
736    /// # Arguments
737    ///
738    /// * `tensor` - The tensor to get the maximum indices of.
739    /// * `dim` - The dimension to get the maximum indices along.
740    ///
741    /// # Returns
742    ///
743    /// The indices of the maximum elements along the dimension.
744    fn int_argmax(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
745
746    /// Gets the indices of the minimum elements along a dimension.
747    ///
748    /// # Arguments
749    ///
750    /// * `tensor` - The tensor to get the minimum indices of.
751    /// * `dim` - The dimension to get the minimum indices along.
752    ///
753    /// # Returns
754    ///
755    /// The indices of the minimum elements along the dimension.
756    fn int_argmin(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
757
758    /// Gets the maximum element in the tensor.
759    ///
760    /// # Arguments
761    ///
762    /// * `tensor` - The tensor to get the maximum element of.
763    ///
764    /// # Returns
765    ///
766    /// The maximum element in the tensor.
767    fn int_max(tensor: IntTensor<B>) -> IntTensor<B> {
768        let shape = tensor.shape();
769        let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()]));
770
771        B::int_max_dim(tensor, 0)
772    }
773
774    /// Gets the maximum element in the tensor along a dimension.
775    ///
776    /// # Arguments
777    ///
778    /// * `tensor` - The tensor to get the maximum element of.
779    /// * `dim` - The dimension to get the maximum element along.
780    ///
781    /// # Returns
782    ///
783    /// The maximum element in the tensor along the dimension.
784    fn int_max_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
785        let index = B::int_argmax(tensor.clone(), dim);
786        B::int_gather(dim, tensor, index)
787    }
788
789    /// Gets the maximum elements and corresponding indices along a dimension.
790    ///
791    /// # Arguments
792    ///
793    /// * `tensor` - The tensor to get the maximum elements and indices of.
794    /// * `dim` - The dimension to get the maximum elements and indices along.
795    ///
796    /// # Returns
797    ///
798    /// The maximum elements and corresponding indices along the dimension.
799    fn int_max_dim_with_indices(tensor: IntTensor<B>, dim: usize) -> (IntTensor<B>, IntTensor<B>) {
800        let index = B::int_argmax(tensor.clone(), dim);
801        let values = B::int_gather(dim, tensor, index.clone());
802
803        (values, index)
804    }
805
806    /// Gets the maximum absolute element in the tensor.
807    ///
808    /// # Arguments
809    ///
810    /// * `tensor` - The tensor to get the maximum element of.
811    ///
812    /// # Returns
813    ///
814    /// The maximum element in the tensor.
815    fn int_max_abs(tensor: IntTensor<B>) -> IntTensor<B> {
816        let shape = tensor.shape();
817        let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()]));
818
819        B::int_max_abs_dim(tensor, 0)
820    }
821
822    /// Gets the maximum absolute element in the tensor along a dimension.
823    ///
824    /// # Arguments
825    ///
826    /// * `tensor` - The tensor to get the maximum element of.
827    /// * `dim` - The dimension to get the maximum element along.
828    ///
829    /// # Returns
830    ///
831    /// The maximum element in the tensor along the dimension.
832    fn int_max_abs_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
833        B::int_max_dim(B::int_abs(tensor), dim)
834    }
835
836    /// Gets the minimum element in the tensor.
837    ///
838    /// # Arguments
839    ///
840    /// * `tensor` - The tensor to get the minimum element of.
841    ///
842    /// # Returns
843    ///
844    /// The minimum element in the tensor.
845    fn int_min(tensor: IntTensor<B>) -> IntTensor<B> {
846        let shape = tensor.shape();
847        let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()]));
848
849        B::int_min_dim(tensor, 0)
850    }
851
852    /// Gets the minimum elements in the tensor along a dimension.
853    ///
854    /// # Arguments
855    ///
856    /// * `tensor` - The tensor to get the minimum element of.
857    /// * `dim` - The dimension to get the minimum element along.
858    ///
859    /// # Returns
860    ///
861    /// The minimum element in the tensor along the dimension.
862    fn int_min_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
863        let index = B::int_argmin(tensor.clone(), dim);
864        B::int_gather(dim, tensor, index)
865    }
866
867    /// Gets the minimum elements and corresponding indices along a dimension.
868    ///
869    /// # Arguments
870    ///
871    /// * `tensor` - The tensor to get the minimum elements and indices of.
872    /// * `dim` - The dimension to get the minimum elements and indices along.
873    ///
874    /// # Returns
875    ///
876    /// The minimum elements and corresponding indices along the dimension.
877    fn int_min_dim_with_indices(tensor: IntTensor<B>, dim: usize) -> (IntTensor<B>, IntTensor<B>) {
878        let indices = B::int_argmin(tensor.clone(), dim);
879        let values = B::int_gather(dim, tensor, indices.clone());
880
881        (values, indices)
882    }
883
884    /// Returns a new tensor with absolute values.
885    ///
886    /// # Arguments
887    ///
888    /// * `tensor` - The tensor to take absolute value of.
889    ///
890    /// # Returns
891    ///
892    /// A tensor with the same shape as `tensor` with absolute values.
893    fn int_abs(tensor: IntTensor<B>) -> IntTensor<B>;
894
895    /// Transposes an int tensor.
896    ///
897    /// # Arguments
898    ///
899    /// * `tensor` - The tensor to transpose.
900    ///
901    /// # Returns
902    ///
903    /// The transposed tensor.
904    fn int_transpose(tensor: IntTensor<B>) -> IntTensor<B> {
905        let ndims = tensor.shape().num_dims();
906        Self::int_swap_dims(tensor, ndims - 2, ndims - 1)
907    }
908
909    /// Swaps two dimensions of an int tensor.
910    ///
911    /// # Arguments
912    ///
913    /// * `tensor` - The tensor to swap the dimensions of.
914    /// * `dim1` - The first dimension to swap.
915    /// * `dim2` - The second dimension to swap.
916    ///
917    /// # Returns
918    ///
919    /// The tensor with the dimensions swapped.
920    fn int_swap_dims(tensor: IntTensor<B>, dim1: usize, dim2: usize) -> IntTensor<B>;
921
922    /// Permutes the dimensions of a tensor.
923    ///
924    /// # Arguments
925    ///
926    /// * `tensor` - The tensor to permute the dimensions of.
927    /// * `axes` - The new order of the dimensions.
928    /// # Returns
929    ///
930    /// The tensor with the dimensions permuted.
931    fn int_permute(tensor: IntTensor<B>, axes: &[usize]) -> IntTensor<B>;
932
933    /// Reverse the order of elements in a tensor along the given axes.
934    ///
935    /// # Arguments
936    ///
937    /// * `tensor` - The tensor to reverse.
938    /// * `axes` - The axes to reverse.
939    ///
940    /// The tensor with the elements reversed.
941    fn int_flip(tensor: IntTensor<B>, axes: &[usize]) -> IntTensor<B>;
942
943    /// Returns a new tensor with the given dimension narrowed to the given range.
944    ///
945    /// # Arguments
946    ///
947    /// * `dim` - The dimension along which the tensor will be narrowed.
948    /// * `start` - The starting point of the given range.
949    /// * `length` - The ending point of the given range.
950    /// # Panics
951    ///
952    /// - If the dimension is greater than the number of dimensions of the tensor.
953    /// - If the given range exceeds the number of elements on the given dimension.
954    ///
955    /// # Returns
956    ///
957    /// A new tensor with the given dimension narrowed to the given range.
958    fn int_narrow(tensor: IntTensor<B>, dim: usize, start: usize, length: usize) -> IntTensor<B> {
959        narrow::<B, Int>(tensor, dim, start, length)
960    }
961
962    /// Split the tensor along the given dimension into chunks.
963    ///
964    /// # Arguments
965    ///
966    /// * `tensor` - The tensor.
967    /// * `chunks` - The number of chunks to be produced.
968    /// * `times` - The dimension along which the tensor will be split.
969    ///
970    /// # Returns
971    ///
972    /// A vector of tensors
973    fn int_chunk(tensor: IntTensor<B>, chunks: usize, dim: usize) -> Vec<IntTensor<B>> {
974        chunk::<B, Int>(tensor, chunks, dim)
975    }
976
977    /// Split the tensor along the given dimension into chunks of `split_size`.
978    ///
979    /// # Arguments
980    ///
981    /// * `tensor` - The tensor.
982    /// * `split_size` - The size of a single chunk.
983    /// * `times` - The dimension along which the tensor will be split.
984    ///
985    /// # Returns
986    ///
987    /// A vector of tensors.
988    fn int_split(tensor: IntTensor<B>, split_size: usize, dim: usize) -> Vec<IntTensor<B>> {
989        split::<B, Int>(tensor, split_size, dim)
990    }
991
992    /// Split the tensor along the given dimension into chunks with sizes in
993    /// `dim` according to `split_sizes`.
994    ///
995    /// # Arguments
996    ///
997    /// * `tensor` - The tensor.
998    /// * `split_sizes` - Vector of sizes for each chunk.
999    /// * `times` - The dimension along which the tensor will be split.
1000    ///
1001    /// # Returns
1002    ///
1003    /// A vector of tensors.
1004    fn int_split_with_sizes(
1005        tensor: IntTensor<B>,
1006        split_sizes: Vec<usize>,
1007        dim: usize,
1008    ) -> Vec<IntTensor<B>> {
1009        split_with_sizes::<B, Int>(tensor, split_sizes, dim)
1010    }
1011
    /// Creates a new int tensor with random values.
    ///
    /// # Arguments
    ///
    /// * `shape` - The shape of the tensor.
    /// * `distribution` - The distribution to sample from.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// The tensor with the given shape and random values.
    // NOTE(review): how samples from a continuous `distribution` are turned into integers
    // (truncation, rounding, ...) is up to each backend implementation — confirm there.
    fn int_random(shape: Shape, distribution: Distribution, device: &Device<B>) -> IntTensor<B>;
1023
1024    /// Creates a new tensor with values from the given range with the given step size.
1025    ///
1026    /// # Arguments
1027    ///
1028    /// * `range` - The range of values.
1029    /// * `step` - The step size.
1030    /// * `device` - The device to create the tensor on.
1031    ///
1032    /// # Returns
1033    ///
1034    /// The tensor with the given values.
1035    fn int_arange_step(range: Range<i64>, step: usize, device: &Device<B>) -> IntTensor<B> {
1036        let value = range
1037            .step_by(step)
1038            .map(|i| i.elem())
1039            .collect::<Vec<IntElem<B>>>();
1040        let shape = Shape::new([value.len()]);
1041        let data = TensorData::new(value, shape);
1042        B::int_from_data(data, device)
1043    }
1044
1045    /// Creates a new tensor with values from the given range.
1046    ///
1047    /// # Arguments
1048    ///
1049    /// * `range` - The range of values.
1050    /// * `device` - The device to create the tensor on.
1051    ///
1052    /// # Returns
1053    ///
1054    /// The tensor with the given values.
1055    ///
1056    /// # Remarks
1057    ///
1058    /// Uses `arange_step` with a step size of 1 under the hood.
1059    fn int_arange(range: Range<i64>, device: &Device<B>) -> IntTensor<B> {
1060        Self::int_arange_step(range, 1, device)
1061    }
1062
1063    /// Tests if any element in the int `tensor` evaluates to True.
1064    ///
1065    /// # Arguments
1066    ///
1067    /// * `tensor` - The tensor to test.
1068    ///
1069    /// # Returns
1070    ///
1071    /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
1072    fn int_any(tensor: IntTensor<B>) -> BoolTensor<B> {
1073        let bool_tensor = B::int_equal_elem(tensor, 0.elem());
1074        let bool_tensor = B::bool_not(bool_tensor);
1075        let sum = B::int_sum(B::bool_into_int(bool_tensor));
1076        B::int_greater_elem(sum, 0.elem())
1077    }
1078
1079    /// Tests if any element in the int `tensor` evaluates to True along a given dimension `dim`.
1080    ///
1081    /// # Arguments
1082    ///
1083    /// * `tensor` - The tensor to test.
1084    /// * `dim` - The axis along which to test.
1085    ///
1086    /// # Returns
1087    ///
1088    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1089    /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input
1090    /// evaluates to True, False otherwise.
1091    fn int_any_dim(tensor: IntTensor<B>, dim: usize) -> BoolTensor<B> {
1092        let bool_tensor = B::int_equal_elem(tensor, 0.elem());
1093        let bool_tensor = B::bool_not(bool_tensor);
1094        let sum = B::int_sum_dim(B::bool_into_int(bool_tensor), dim);
1095        B::int_greater_elem(sum, 0.elem())
1096    }
1097
1098    /// Tests if all elements in the int `tensor` evaluate to True.
1099    ///
1100    /// # Arguments
1101    ///
1102    /// * `tensor` - The tensor to test.
1103    ///
1104    /// # Returns
1105    ///
1106    /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1107    /// evaluate to True, False otherwise.
1108    fn int_all(tensor: IntTensor<B>) -> BoolTensor<B> {
1109        let num_elems = tensor.shape().num_elements();
1110        let bool_tensor = B::int_equal_elem(tensor, 0.elem());
1111        let bool_tensor = B::bool_not(bool_tensor);
1112        let sum = B::int_sum(B::bool_into_int(bool_tensor));
1113        B::int_equal_elem(sum, (num_elems as i32).elem())
1114    }
1115
1116    /// Tests if all elements in the int `tensor` evaluate to True along a given dimension `dim`.
1117    ///
1118    /// # Arguments
1119    ///
1120    /// * `tensor` - The tensor to test.
1121    /// * `dim` - The axis along which to test.
1122    ///
1123    /// # Returns
1124    ///
1125    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1126    /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1127    /// evaluates to True, False otherwise.
1128    fn int_all_dim(tensor: IntTensor<B>, dim: usize) -> BoolTensor<B> {
1129        let num_elems = tensor.shape().dims[dim];
1130        let bool_tensor = B::int_equal_elem(tensor, 0.elem());
1131        let bool_tensor = B::bool_not(bool_tensor);
1132        let sum = B::int_sum_dim(B::bool_into_int(bool_tensor), dim);
1133        B::int_equal_elem(sum, (num_elems as i32).elem())
1134    }
1135
1136    /// Returns the signs of the int `tensor`.
1137    ///
1138    /// # Arguments
1139    ///
1140    /// * `tensor` - The tensor to extract the signs from.
1141    ///
1142    /// # Returns
1143    ///
1144    /// A tensor with the same shape as `tensor` containing the signs of the elements of `tensor`.
1145    fn int_sign(tensor: IntTensor<B>) -> IntTensor<B> {
1146        let zeros = B::int_zeros(tensor.shape(), &B::int_device(&tensor));
1147        let less_than_zero = B::int_lower_elem(tensor.clone(), 0.0f32.elem());
1148        let greater_than_zero = B::int_greater_elem(tensor, 0.0f32.elem());
1149
1150        let mut result = B::int_mask_fill(zeros, less_than_zero, (-1.0f32).elem());
1151        result = B::int_mask_fill(result, greater_than_zero, 1.0f32.elem());
1152        result
1153    }
1154
    /// Broadcasts the int `tensor` to the given `shape`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to broadcast.
    /// * `shape` - The target shape.
    ///
    /// # Returns
    ///
    /// The broadcast tensor.
    // NOTE(review): the exact broadcasting rules (which dims may be expanded, whether new
    // leading dims are allowed) are defined by each backend implementation — confirm there.
    fn int_expand(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B>;
1157
1158    /// Sort the elements of the input `tensor` by value along a given dimension.
1159    ///
1160    /// This sort is unstable (i.e., may reorder equal elements).
1161    ///
1162    /// # Arguments
1163    ///
1164    /// * `tensor` - The input tensor.
1165    /// * `dim` - The axis along which to sort.
1166    /// * `descending` - The sorting order.
1167    ///
1168    /// # Returns
1169    ///
1170    /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
1171    fn int_sort(tensor: IntTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
1172        sort::<B, Int>(tensor, dim, descending)
1173    }
1174
1175    /// Sort the elements of the input `tensor` by value along a given dimension.
1176    ///
1177    /// This sort is unstable (i.e., may reorder equal elements).
1178    ///
1179    /// # Arguments
1180    ///
1181    /// * `tensor` - The input tensor.
1182    /// * `dim` - The axis along which to sort.
1183    ///
1184    /// # Returns
1185    ///
1186    /// A tensor with the same shape as the input tensor and corresponding indices, where
1187    /// the elements are sorted by value and the indices map back to the original input tensor.
1188    fn int_sort_with_indices(
1189        tensor: IntTensor<B>,
1190        dim: usize,
1191        descending: bool,
1192    ) -> (IntTensor<B>, IntTensor<B>) {
1193        sort_with_indices::<B, Int>(tensor, dim, descending)
1194    }
1195
1196    /// Returns the indices that sort the elements of the input `tensor` by value
1197    /// along a given dimension.
1198    ///
1199    /// This sort is unstable (i.e., may reorder equal elements).
1200    ///
1201    /// # Arguments
1202    ///
1203    /// * `tensor` - The input tensor.
1204    /// * `dim` - The axis along which to sort.
1205    /// * `descending` - The sorting order.
1206    ///
1207    /// # Returns
1208    ///
1209    /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.
1210    fn int_argsort(tensor: IntTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
1211        argsort::<B, Int>(tensor, dim, descending)
1212    }
1213
    // NOTE(review): this trait only fixes the signatures of the bitwise ops below. How the
    // shapes of `lhs` and `rhs` are paired (element-wise, broadcasting) and how negative or
    // out-of-range shift amounts behave are left to each backend implementation — confirm
    // against the backend before relying on edge cases.

    /// Bitwise AND operation for Int Tensors.
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    fn bitwise_and(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

    /// Bitwise AND operation for Int Tensors with a scalar.
    ///
    /// * `lhs` - The tensor.
    /// * `rhs` - The scalar operand (presumably applied to every element of `lhs`).
    fn bitwise_and_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;

    /// Bitwise OR operation for Int Tensors.
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    fn bitwise_or(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

    /// Bitwise OR operation for Int Tensors with a scalar.
    ///
    /// * `lhs` - The tensor.
    /// * `rhs` - The scalar operand (presumably applied to every element of `lhs`).
    fn bitwise_or_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;

    /// Bitwise XOR operation for Int Tensors.
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    fn bitwise_xor(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

    /// Bitwise XOR operation for Int Tensors with a scalar.
    ///
    /// * `lhs` - The tensor.
    /// * `rhs` - The scalar operand (presumably applied to every element of `lhs`).
    fn bitwise_xor_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;

    /// Bitwise NOT operation for Int Tensors.
    ///
    /// * `tensor` - The tensor whose bits are inverted.
    fn bitwise_not(tensor: IntTensor<B>) -> IntTensor<B>;

    /// Bitwise left shift operation for Int Tensors.
    ///
    /// * `lhs` - The tensor whose values are shifted.
    /// * `rhs` - The tensor of shift amounts.
    fn bitwise_left_shift(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

    /// Bitwise left shift operation for Int Tensors with a scalar.
    ///
    /// * `lhs` - The tensor whose values are shifted.
    /// * `rhs` - The scalar shift amount.
    fn bitwise_left_shift_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;

    /// Bitwise right shift operation for Int Tensors.
    ///
    /// * `lhs` - The tensor whose values are shifted.
    /// * `rhs` - The tensor of shift amounts.
    fn bitwise_right_shift(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

    /// Bitwise right shift operation for Int Tensors with a scalar.
    ///
    /// * `lhs` - The tensor whose values are shifted.
    /// * `rhs` - The scalar shift amount.
    fn bitwise_right_shift_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
1246}