Skip to main content

burn_backend/backend/ops/
int_tensor.rs

1use super::cat::cat_with_slice_assign;
2use super::repeat_dim::repeat_with_slice_assign;
3use super::sort::{argsort, sort, sort_with_indices};
4use crate::tensor::{BoolTensor, Device, FloatTensor, Int, IntElem, IntTensor};
5use crate::{Backend, Distribution, TensorData, TensorMetadata, element::ElementConversion};
6use crate::{ExecutionError, Scalar, get_device_settings};
7use alloc::vec::Vec;
8use burn_std::{BoolDType, FloatDType, IntDType, Shape, Slice};
9use core::ops::Range;
10
11/// Int Tensor API for basic and numeric operations, see
12#[cfg_attr(doc, doc = crate::doc_tensor!())]
13#[cfg_attr(not(doc), doc = "`Tensor`")]
14/// for documentation on each function.
15pub trait IntTensorOps<B: Backend> {
16    /// Creates a new int tensor.
17    ///
18    /// # Arguments
19    ///
20    /// * `shape` - The shape of the tensor.
21    /// * `device` - The device to create the tensor on.
22    /// * `dtype` - The target data type.
23    ///
24    /// # Returns
25    ///
26    /// The integer tensor with the given shape.
27    fn int_empty(shape: Shape, device: &Device<B>, dtype: IntDType) -> IntTensor<B>;
28
29    /// Converts the tensor to a data structure.
30    ///
31    /// # Arguments
32    ///
33    /// * `tensor` - The tensor.
34    ///
35    /// # Returns
36    ///
37    /// The data structure with the tensor's data.
38    fn int_into_data(
39        tensor: IntTensor<B>,
40    ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;
41
42    /// Creates a tensor from the data structure.
43    ///
44    /// # Arguments
45    ///
46    /// * `data` - The data structure.
47    /// * `device` - The device to create the tensor on.
48    ///
49    /// # Returns
50    ///
51    /// The tensor with the data.
52    fn int_from_data(data: TensorData, device: &Device<B>) -> IntTensor<B>;
53
54    /// Gets the device of the tensor.
55    ///
56    /// # Arguments
57    ///
58    /// * `tensor` - The tensor.
59    ///
60    /// # Returns
61    ///
62    /// The device of the tensor.
63    fn int_device(tensor: &IntTensor<B>) -> Device<B>;
64
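    /// Moves the tensor to the given device.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `device` - The device to move the tensor to.
    ///
    /// # Returns
    ///
    /// The tensor on the given device.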
66    fn int_to_device(tensor: IntTensor<B>, device: &Device<B>) -> IntTensor<B>;
67
68    /// Reshapes the tensor.
69    ///
70    /// # Arguments
71    ///
72    /// * `tensor` - The tensor.
73    /// * `shape` - The new shape.
74    ///
75    /// # Returns
76    ///
77    /// The tensor with the new shape.
78    fn int_reshape(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B>;
79
80    /// Gets the element at the given indices.
81    ///
82    /// # Arguments
83    ///
84    /// * `tensor` - The tensor.
85    /// * `slices` - The slices specifying ranges and steps for each dimension.
86    ///
87    /// # Returns
88    ///
89    /// The elements at the given indices.
90    ///
91    /// # Note
92    ///
93    /// Empty slices (where start >= end) are handled at the high-level tensor API and will not
94    /// be passed to this method. Backend implementations do not need to handle empty slices.
95    fn int_slice(tensor: IntTensor<B>, slices: &[Slice]) -> IntTensor<B>;
96
97    /// Sets the values in the tensor for the given ranges.
98    ///
99    /// # Arguments
100    ///
101    /// * `tensor` - The tensor.
    /// * `slices` - The slices selecting the region whose values are set.
    /// * `value` - The tensor holding the values to write into that region.
103    ///
104    /// # Returns
105    ///
106    /// The tensor with the values set for the given ranges.
107    ///
108    /// # Note
109    ///
110    /// Empty slice assignments (where any slice range produces 0 elements) are handled at the
111    /// high-level tensor API and will not be passed to this method. Backend implementations do
112    /// not need to handle empty slice assignments.
113    fn int_slice_assign(
114        tensor: IntTensor<B>,
115        slices: &[Slice],
116        value: IntTensor<B>,
117    ) -> IntTensor<B>;
118
119    /// Converts int tensor to float tensor.
120    ///
121    /// # Arguments
122    ///
123    /// * `tensor` - The tensor.
124    /// * `out_dtype` - The output tensor dtype.
125    ///
126    /// # Returns
127    ///
    /// The float tensor with the same data as the int tensor.
129    fn int_into_float(tensor: IntTensor<B>, out_dtype: FloatDType) -> FloatTensor<B>;
130
131    /// Fills the tensor with values from the value tensor if the mask is true at the given
132    /// indices.
133    ///
134    /// # Arguments
135    ///
136    /// * `tensor` - The tensor.
137    /// * `mask` - The mask.
138    /// * `value` - The value tensor.
139    ///
140    /// # Returns
141    ///
142    /// The tensor with the values filled.
143    fn int_mask_where(
144        tensor: IntTensor<B>,
145        mask: BoolTensor<B>,
146        value: IntTensor<B>,
147    ) -> IntTensor<B>;
148
149    /// Fills the tensor with the given value if the mask is true at the given indices.
150    ///
151    /// # Arguments
152    ///
153    /// * `tensor` - The tensor.
154    /// * `mask` - The mask.
155    /// * `value` - The value.
156    ///
157    /// # Returns
158    ///
159    /// The tensor with the values filled.
160    fn int_mask_fill(tensor: IntTensor<B>, mask: BoolTensor<B>, value: Scalar) -> IntTensor<B>;
161
162    /// Gather elements from the tensor at the given indices.
163    ///
164    /// # Arguments
165    ///
166    /// * `dim` - The dimension to gather from.
167    /// * `tensor` - The tensor.
    /// * `indices` - The indices.
    ///
    /// # Returns
    ///
    /// The tensor with the gathered elements.
169    fn int_gather(dim: usize, tensor: IntTensor<B>, indices: IntTensor<B>) -> IntTensor<B>;
170
171    /// Scatter a given value to the tensor at the given indices using sum reduction.
172    ///
173    /// # Arguments
174    ///
175    /// * `dim` - The dimension to scatter to.
176    /// * `tensor` - The tensor.
177    /// * `indices` - The indices.
178    /// * `value` - The value.
179    ///
180    /// # Returns
181    ///
182    /// The tensor with the values scattered.
183    fn int_scatter_add(
184        dim: usize,
185        tensor: IntTensor<B>,
186        indices: IntTensor<B>,
187        value: IntTensor<B>,
188    ) -> IntTensor<B>;
189
190    /// Multi-dimensional scatter for int tensors.
191    fn int_scatter_nd(
192        _data: IntTensor<B>,
193        _indices: IntTensor<B>,
194        _values: IntTensor<B>,
195        _reduction: crate::tensor::IndexingUpdateOp,
196    ) -> IntTensor<B> {
197        unimplemented!("int_scatter_nd is not implemented for this backend")
198    }
199
200    /// Multi-dimensional gather for int tensors.
201    fn int_gather_nd(_data: IntTensor<B>, _indices: IntTensor<B>) -> IntTensor<B> {
202        unimplemented!("int_gather_nd is not implemented for this backend")
203    }
204
205    /// Select tensor elements along the given dimension corresponding to the given indices.
206    ///
207    /// # Arguments
208    ///
209    /// * `tensor` - The tensor.
210    /// * `dim` - The dimension to select from.
211    /// * `indices` - The indices.
212    ///
213    /// # Returns
214    ///
215    /// The tensor with the selected elements.
216    fn int_select(tensor: IntTensor<B>, dim: usize, indices: IntTensor<B>) -> IntTensor<B>;
217
218    /// Assign the selected elements along the given dimension corresponding to the given indices
219    /// to the given value using sum reduction.
220    ///
221    /// # Arguments
222    ///
223    /// * `tensor` - The tensor.
224    /// * `dim` - The dimension to select from.
225    /// * `indices` - The indices.
226    /// * `value` - The value.
227    ///
228    /// # Returns
229    ///
230    /// The tensor with the selected elements assigned to the given value.
231    fn int_select_add(
232        tensor: IntTensor<B>,
233        dim: usize,
234        indices: IntTensor<B>,
235        value: IntTensor<B>,
236    ) -> IntTensor<B>;
237
238    /// Repeats the tensor along the given dimension the given number of times.
239    ///
240    /// # Arguments
241    ///
242    /// * `tensor` - The tensor.
243    /// * `dim` - The dimension to repeat.
244    /// * `times` - The number of times to repeat.
245    ///
246    /// # Returns
247    ///
248    /// The tensor with the given dimension repeated the given number of times.
249    fn int_repeat_dim(tensor: IntTensor<B>, dim: usize, times: usize) -> IntTensor<B> {
250        repeat_with_slice_assign::<B, Int>(tensor, dim, times)
251    }
252
253    /// Concatenates the given tensors along the given dimension.
254    ///
255    /// # Arguments
256    ///
257    /// * `tensors` - The tensors.
258    /// * `dim` - The dimension to concatenate along.
259    ///
260    /// # Returns
261    ///
262    /// The concatenated tensor.
263    ///
264    /// # Note
265    ///
266    /// Empty tensors (where the concatenation dimension has size 0) are filtered out at the
267    /// high-level tensor API and will not be passed to this method. Backend implementations do
268    /// not need to handle empty tensors.
269    fn int_cat(tensors: Vec<IntTensor<B>>, dim: usize) -> IntTensor<B> {
270        cat_with_slice_assign::<B, Int>(tensors, dim)
271    }
272
273    /// Element-wise equality comparison.
274    ///
275    /// # Arguments
276    ///
277    /// * `lhs` - The left-hand side tensor.
278    /// * `rhs` - The right-hand side tensor.
279    /// * `out_dtype` - The output tensor dtype.
280    ///
281    /// # Returns
282    ///
283    /// The boolean tensor with the result of the comparison.
284    fn int_equal(lhs: IntTensor<B>, rhs: IntTensor<B>, out_dtype: BoolDType) -> BoolTensor<B>;
285
286    /// Element-wise non-equality comparison.
287    ///
288    /// # Arguments
289    ///
290    /// * `lhs` - The left-hand side tensor.
291    /// * `rhs` - The right-hand side tensor.
292    /// * `out_dtype` - The output tensor dtype.
293    ///
294    /// # Returns
295    ///
296    /// The boolean tensor with the result of the comparison.
297    fn int_not_equal(lhs: IntTensor<B>, rhs: IntTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
298        let equal_tensor = B::int_equal(lhs, rhs, out_dtype);
299        B::bool_not(equal_tensor)
300    }
301
302    /// Element-wise equality comparison with a scalar.
303    ///
304    /// # Arguments
305    ///
306    /// * `lhs` - The left-hand side tensor.
307    /// * `rhs` - The right-hand side scalar.
308    /// * `out_dtype` - The output tensor dtype.
309    ///
310    /// # Returns
311    ///
312    /// The boolean tensor with the result of the comparison.
313    fn int_equal_elem(lhs: IntTensor<B>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<B>;
314
315    /// Element-wise non-equality comparison with a scalar.
316    ///
317    /// # Arguments
318    ///
319    /// * `lhs` - The left-hand side tensor.
320    /// * `rhs` - The right-hand side scalar.
321    /// * `out_dtype` - The output tensor dtype.
322    ///
323    /// # Returns
324    ///
325    /// The boolean tensor with the result of the comparison.
326    fn int_not_equal_elem(lhs: IntTensor<B>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<B> {
327        let equal_tensor = B::int_equal_elem(lhs, rhs, out_dtype);
328        B::bool_not(equal_tensor)
329    }
330
331    /// Element-wise greater than comparison.
332    ///
333    /// # Arguments
334    ///
335    /// * `lhs` - The left-hand side tensor.
336    /// * `rhs` - The right-hand side tensor.
337    /// * `out_dtype` - The output tensor dtype.
338    ///
339    /// # Returns
340    ///
341    /// The boolean tensor with the result of the comparison.
342    fn int_greater(lhs: IntTensor<B>, rhs: IntTensor<B>, out_dtype: BoolDType) -> BoolTensor<B>;
343
344    /// Element-wise greater than comparison with a scalar.
345    ///
346    /// # Arguments
347    ///
348    /// * `lhs` - The left-hand side tensor.
349    /// * `rhs` - The right-hand side scalar.
350    /// * `out_dtype` - The output tensor dtype.
351    ///
352    /// # Returns
353    ///
354    /// The boolean tensor with the result of the comparison.
355    fn int_greater_elem(lhs: IntTensor<B>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<B>;
356
357    /// Element-wise greater than or equal comparison.
358    ///
359    /// # Arguments
360    ///
361    /// * `lhs` - The left-hand side tensor.
362    /// * `rhs` - The right-hand side tensor.
363    /// * `out_dtype` - The output tensor dtype.
364    ///
365    /// # Returns
366    ///
367    /// The boolean tensor with the result of the comparison.
368    fn int_greater_equal(
369        lhs: IntTensor<B>,
370        rhs: IntTensor<B>,
371        out_dtype: BoolDType,
372    ) -> BoolTensor<B>;
373
374    /// Element-wise greater than or equal comparison with a scalar.
375    ///
376    /// # Arguments
377    ///
378    /// * `lhs` - The left-hand side tensor.
379    /// * `rhs` - The right-hand side scalar.
380    /// * `out_dtype` - The output tensor dtype.
381    ///
382    /// # Returns
383    ///
384    /// The boolean tensor with the result of the comparison.
385    fn int_greater_equal_elem(
386        lhs: IntTensor<B>,
387        rhs: Scalar,
388        out_dtype: BoolDType,
389    ) -> BoolTensor<B>;
390
391    /// Element-wise less than comparison.
392    ///
393    /// # Arguments
394    ///
395    /// * `lhs` - The left-hand side tensor.
396    /// * `rhs` - The right-hand side tensor.
397    /// * `out_dtype` - The output tensor dtype.
398    ///
399    /// # Returns
400    ///
401    /// The boolean tensor with the result of the comparison.
402    fn int_lower(lhs: IntTensor<B>, rhs: IntTensor<B>, out_dtype: BoolDType) -> BoolTensor<B>;
403
404    /// Element-wise less than comparison with a scalar.
405    ///
406    /// # Arguments
407    ///
408    /// * `lhs` - The left-hand side tensor.
409    /// * `rhs` - The right-hand side scalar.
410    /// * `out_dtype` - The output tensor dtype.
411    ///
412    /// # Returns
413    ///
414    /// The boolean tensor with the result of the comparison.
415    fn int_lower_elem(lhs: IntTensor<B>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<B>;
416
417    /// Element-wise less than or equal comparison.
418    ///
419    /// # Arguments
420    ///
421    /// * `lhs` - The left-hand side tensor.
422    /// * `rhs` - The right-hand side tensor.
423    /// * `out_dtype` - The output tensor dtype.
424    ///
425    /// # Returns
426    ///
427    /// The boolean tensor with the result of the comparison.
428    fn int_lower_equal(lhs: IntTensor<B>, rhs: IntTensor<B>, out_dtype: BoolDType)
429    -> BoolTensor<B>;
430
431    /// Element-wise less than or equal comparison with a scalar.
432    ///
433    /// # Arguments
434    ///
435    /// * `lhs` - The left-hand side tensor.
436    /// * `rhs` - The right-hand side scalar.
437    /// * `out_dtype` - The output tensor dtype.
438    ///
439    /// # Returns
440    ///
441    /// The boolean tensor with the result of the comparison.
442    fn int_lower_equal_elem(lhs: IntTensor<B>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<B>;
443
444    // ====  NUMERIC ==== //
445
446    /// Element-wise addition.
447    ///
448    /// # Arguments
449    ///
450    /// * `lhs` - The left-hand side tensor.
451    /// * `rhs` - The right-hand side tensor.
452    ///
453    /// # Returns
454    ///
455    /// The result of the addition.
456    fn int_add(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
457
458    /// Element-wise addition with a scalar.
459    ///
460    /// # Arguments
461    ///
462    /// * `lhs` - The left-hand side tensor.
463    /// * `rhs` - The right-hand side scalar.
464    ///
465    /// # Returns
466    ///
467    /// The result of the addition.
468    fn int_add_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
469
470    /// Element-wise power with a IntTensor.
471    ///
472    /// # Arguments
473    ///
474    /// * `lhs` - The left-hand side IntTensor.
475    /// * `rhs` - The right-hand side IntTensor.
476    ///
477    /// # Returns
478    ///
479    /// The elements of `lhs` raised to the power of the elements of `rhs`.
480    fn int_powi(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
481        let dtype = lhs.dtype();
482        let float_dtype = get_device_settings::<B>(&B::int_device(&lhs)).float_dtype;
483        B::float_into_int(
484            B::float_powi(B::int_into_float(lhs, float_dtype), rhs),
485            dtype.into(),
486        )
487    }
488
489    /// Element-wise power with a scalar.
490    ///
491    /// # Backend Implementors Note
492    ///
493    /// A number of common exponent cases can be implemented with operations
494    /// which are much cheaper than generic exponentiation.
495    ///
496    /// This (`Backend` impl overridable) operation handles generic optimizations
497    /// for several common integer exponent cases; and then dispatches to
498    /// the (`Backend` impl overridable) [`Self::int_powi_scalar_impl`]
499    /// operation to handle the generic case.
500    ///
501    /// # Arguments
502    ///
503    /// * `lhs` - The left-hand side tensor.
504    /// * `rhs` - The right-hand side scalar.
505    ///
506    /// # Returns
507    ///
508    /// The elements of `lhs` raised to the value of `rhs`.
509    fn int_powi_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B> {
510        let exp = rhs.elem::<i32>();
511        match exp {
512            0 => Self::int_ones(lhs.shape(), &B::int_device(&lhs), lhs.dtype().into()),
513            1 => lhs,
514            2 => Self::int_mul(lhs.clone(), lhs),
515            _ => Self::int_powi_scalar_impl(lhs, rhs),
516        }
517    }
518
519    /// Element-wise power with a scalar.
520    ///
521    /// # Backend Implementors Note
522    ///
523    /// This is the generic implementation of integer exponentiation
524    /// called by [`Self::int_powi_scalar`] in the fallback case.
525    ///
526    /// By default, this performs a relatively expensive conversion to float,
527    /// exponentiation in float, and conversion back to int.
528    /// This reduces the minimal operation set for `Backend`s,
529    /// at the cost of performance.
530    ///
531    /// This is a good target for specialized optimizations in `Backend` implementations.
532    ///
533    /// As a general rule, this should not be called directly.
534    ///
535    /// # Arguments
536    ///
537    /// * `lhs` - The left-hand side tensor.
538    /// * `rhs` - The right-hand side scalar.
539    ///
540    /// # Returns
541    ///
542    /// The elements of `lhs` raised to the value of `rhs`.
543    fn int_powi_scalar_impl(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B> {
544        let dtype = lhs.dtype();
545        let float_dtype = get_device_settings::<B>(&B::int_device(&lhs)).float_dtype;
546        B::float_into_int(
547            B::float_powi_scalar_impl(B::int_into_float(lhs, float_dtype), rhs),
548            dtype.into(),
549        )
550    }
551
552    /// Clamps a tensor under a minimum value.
553    ///
554    /// # Arguments
555    ///
556    /// * `tensor` - The tensor to clamp.
557    /// * `min` - The minimum value.
558    ///
559    /// # Returns
560    ///
561    /// The clamped tensor.
562    fn int_clamp_min(tensor: IntTensor<B>, min: Scalar) -> IntTensor<B> {
563        let dtype = get_device_settings::<B>(&B::int_device(&tensor)).bool_dtype;
564        let mask = Self::int_lower_elem(tensor.clone(), min, dtype);
565        Self::int_mask_fill(tensor, mask, min)
566    }
567
568    /// Clamps a tensor over a maximum value.
569    ///
570    /// # Arguments
571    ///
572    /// * `tensor` - The tensor to clamp.
573    /// * `max` - The maximum value.
574    ///
575    /// # Returns
576    ///
577    /// The clamped tensor.
578    fn int_clamp_max(tensor: IntTensor<B>, max: Scalar) -> IntTensor<B> {
579        let dtype = get_device_settings::<B>(&B::int_device(&tensor)).bool_dtype;
580        let mask = Self::int_greater_elem(tensor.clone(), max, dtype);
581        Self::int_mask_fill(tensor, mask, max)
582    }
583
584    /// Clamps a tensor between a minimum and maximum value.
585    ///
586    /// # Arguments
587    ///
588    /// * `tensor` - The tensor to clamp.
589    /// * `min` - The minimum value.
590    /// * `max` - The maximum value.
591    ///
592    /// # Returns
593    ///
594    /// The clamped tensor.
595    fn int_clamp(tensor: IntTensor<B>, min: Scalar, max: Scalar) -> IntTensor<B> {
596        Self::int_clamp_min(Self::int_clamp_max(tensor, max), min)
597    }
598
599    /// Element-wise subtraction.
600    ///
601    /// # Arguments
602    ///
603    /// * `lhs` - The left-hand side tensor.
604    /// * `rhs` - The right-hand side tensor.
605    ///
606    /// # Returns
607    ///
608    /// The result of the subtraction.
609    fn int_sub(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
610
611    /// Element-wise subtraction with a scalar.
612    ///
613    /// # Arguments
614    ///
615    /// * `lhs` - The left-hand side tensor.
616    /// * `rhs` - The right-hand side scalar.
617    ///
618    /// # Returns
619    ///
620    /// The result of the subtraction.
621    fn int_sub_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
622
623    /// Element-wise multiplication.
624    ///
625    /// # Arguments
626    ///
627    /// * `lhs` - The left-hand side tensor.
628    /// * `rhs` - The right-hand side tensor.
629    ///
630    /// # Returns
631    ///
632    /// The result of the multiplication.
633    fn int_mul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
634
635    /// Element-wise multiplication with a scalar.
636    ///
637    /// # Arguments
638    ///
639    /// * `lhs` - The left-hand side tensor.
640    /// * `rhs` - The right-hand side scalar.
641    ///
642    /// # Returns
643    ///
644    /// The result of the multiplication.
645    fn int_mul_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
646
647    /// Element-wise division.
648    ///
649    /// # Arguments
650    ///
651    /// * `lhs` - The left-hand side tensor.
652    /// * `rhs` - The right-hand side tensor.
653    ///
654    /// # Returns
655    ///
656    /// The result of the division.
657    fn int_div(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
658
659    /// Element-wise division with a scalar.
660    ///
661    /// # Arguments
662    ///
663    /// * `lhs` - The left-hand side tensor.
664    /// * `rhs` - The right-hand side scalar.
665    ///
666    /// # Returns
667    ///
668    /// The result of the division.
669    fn int_div_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
670
    /// Element-wise modulus.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// The result of applying the modulus of `rhs` to `lhs`, element by element.
680    fn int_remainder(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
681
682    /// Element-wise modulus with a scalar.
683    ///
    /// # Arguments
    ///
685    /// * `lhs` - The left-hand side tensor.
686    /// * `rhs` - The right-hand side scalar.
687    ///
688    /// # Returns
689    ///
690    /// The result of applying the modulus of the scalar to the tensor.
691    fn int_remainder_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
692
693    /// Multiplies two tensors together using matrix multiplication.
694    ///
695    /// # Arguments
696    ///
697    /// * `lhs` - The left-hand side tensor.
698    /// * `rhs` - The right-hand side tensor.
699    ///
700    /// # Returns
701    ///
702    /// The result of multiplying the two tensors together using matrix multiplication.
703    fn int_matmul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
704
705    /// Element-wise negation.
706    ///
707    /// # Arguments
708    ///
709    /// * `tensor` - The tensor to negate.
710    ///
711    /// # Returns
712    ///
713    /// The negated tensor.
714    fn int_neg(tensor: IntTensor<B>) -> IntTensor<B> {
715        Self::int_mul_scalar(tensor, (-1).into())
716    }
717
718    /// Creates a tensor of zeros.
719    ///
720    /// # Arguments
721    ///
722    /// * `shape` - The shape of the tensor.
723    /// * `device` - The device to create the tensor on.
724    /// * `dtype` - The target data type.
725    ///
726    /// # Returns
727    ///
728    /// The tensor of zeros.
729    fn int_zeros(shape: Shape, device: &Device<B>, dtype: IntDType) -> IntTensor<B> {
730        Self::int_from_data(TensorData::full_dtype(shape, 0, dtype.into()), device)
731    }
732
733    /// Creates a tensor of ones.
734    ///
735    /// # Arguments
736    ///
737    /// * `shape` - The shape of the tensor.
738    /// * `device` - The device to create the tensor on.
739    /// * `dtype` - The target data type.
740    ///
741    /// # Returns
742    ///
743    /// The tensor of ones.
744    fn int_ones(shape: Shape, device: &Device<B>, dtype: IntDType) -> IntTensor<B> {
745        Self::int_from_data(TensorData::full_dtype(shape, 1, dtype.into()), device)
746    }
747
748    /// Creates a tensor filled with given value.
749    ///
750    /// # Arguments
751    ///
752    /// * `shape` - The shape of the tensor.
753    /// * `fill_value` - The value with which to fill the tensor.
754    /// * `device` - The device to create the tensor on.
755    /// * `dtype` - The target data type.
756    ///
757    /// # Returns
758    ///
759    /// The tensor filled with given value
760    fn int_full(
761        shape: Shape,
762        fill_value: Scalar,
763        device: &Device<B>,
764        dtype: IntDType,
765    ) -> IntTensor<B> {
766        Self::int_from_data(
767            TensorData::full_dtype(shape, fill_value, dtype.into()),
768            device,
769        )
770    }
771
772    /// Sums all elements in the tensor.
773    ///
774    /// # Arguments
775    ///
776    /// * `tensor` - The tensor to sum.
777    ///
778    /// # Returns
779    ///
780    /// The sum of all elements in the tensor.
781    fn int_sum(tensor: IntTensor<B>) -> IntTensor<B>;
782
783    /// Sums all elements in the tensor along a dimension.
784    ///
785    /// # Arguments
786    ///
787    /// * `tensor` - The tensor to sum.
788    /// * `dim` - The dimension to sum along.
789    ///
790    /// # Returns
791    ///
792    /// The sum of all elements in the tensor along the dimension.
793    fn int_sum_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
794
795    /// Computes the product of all elements in the tensor.
796    ///
797    /// # Arguments
798    ///
799    /// * `tensor` - The tensor to compute the product of.
800    ///
801    /// # Returns
802    ///
803    /// The product of all elements in the tensor.
804    fn int_prod(tensor: IntTensor<B>) -> IntTensor<B>;
805
806    /// Computes the product of all elements in the tensor along a dimension.
807    ///
808    /// # Arguments
809    ///
810    /// * `tensor` - The tensor to compute the product of.
811    /// * `dim` - The dimension to compute the product along.
812    ///
813    /// # Returns
814    ///
815    /// The product of all elements in the tensor along the dimension.
816    fn int_prod_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
817
818    /// Computes the mean of all elements in the tensor.
819    ///
820    /// # Arguments
821    ///
822    /// * `tensor` - The tensor to compute the mean of.
823    ///
824    /// # Returns
825    ///
826    /// The mean of all elements in the tensor.
827    fn int_mean(tensor: IntTensor<B>) -> IntTensor<B> {
828        let num_elems = tensor.shape().num_elements() as i64;
829        B::int_div_scalar(B::int_sum(tensor), num_elems.into())
830    }
831
832    /// Computes the mean of all elements in the tensor along a dimension.
833    ///
834    /// # Arguments
835    ///
    /// * `tensor` - The tensor to compute the mean of.
    /// * `dim` - The dimension to compute the mean along.
837    ///
838    /// # Returns
839    ///
840    /// The mean of all elements in the tensor along the dimension.
841    fn int_mean_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
842
843    /// Computes the cumulative sum of elements along a dimension.
844    ///
845    /// # Arguments
846    ///
847    /// * `tensor` - The tensor to compute the cumulative sum of.
848    /// * `dim` - The dimension along which to compute the cumulative sum.
849    ///
850    /// # Returns
851    ///
852    /// A tensor with the same shape where each element is the cumulative sum
853    /// of all elements up to and including that position along the dimension.
854    fn int_cumsum(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
855
856    /// Computes the cumulative product of elements along a dimension.
857    ///
858    /// # Arguments
859    ///
860    /// * `tensor` - The tensor to compute the cumulative product of.
861    /// * `dim` - The dimension along which to compute the cumulative product.
862    ///
863    /// # Returns
864    ///
865    /// A tensor with the same shape where each element is the cumulative product
866    /// of all elements up to and including that position along the dimension.
867    fn int_cumprod(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
868
869    /// Computes the cumulative minimum of elements along a dimension.
870    ///
871    /// # Arguments
872    ///
873    /// * `tensor` - The tensor to compute the cumulative minimum of.
874    /// * `dim` - The dimension along which to compute the cumulative minimum.
875    ///
876    /// # Returns
877    ///
878    /// A tensor with the same shape where each element is the minimum
879    /// of all elements up to and including that position along the dimension.
880    fn int_cummin(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
881
882    /// Computes the cumulative maximum of elements along a dimension.
883    ///
884    /// # Arguments
885    ///
886    /// * `tensor` - The tensor to compute the cumulative maximum of.
887    /// * `dim` - The dimension along which to compute the cumulative maximum.
888    ///
889    /// # Returns
890    ///
891    /// A tensor with the same shape where each element is the maximum
892    /// of all elements up to and including that position along the dimension.
893    fn int_cummax(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
894
895    /// Gets the indices of the maximum elements along a dimension.
896    ///
897    /// # Arguments
898    ///
899    /// * `tensor` - The tensor to get the maximum indices of.
900    /// * `dim` - The dimension to get the maximum indices along.
901    ///
902    /// # Returns
903    ///
904    /// The indices of the maximum elements along the dimension.
905    fn int_argmax(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
906
907    /// Gets the indices of the k maximum elements along a dimension.
    /// If two elements share the same value, they are ordered by the lowest
    /// coordinate.
910    ///
911    /// # Arguments
912    ///
913    /// * `tensor` - The tensor to get the maximum indices of.
914    /// * `dim` - The dimension to get the maximum indices along.
915    /// * `k` - number of maximum elements.
916    ///
917    /// # Returns
918    ///
919    /// The indices of the maximum elements along the dimension.
920    fn int_argtopk(tensor: IntTensor<B>, dim: usize, k: usize) -> IntTensor<B>;
921
    /// Gets the values of the k maximum elements along a dimension.
    ///
    /// # Arguments
924    ///
925    /// * `tensor` - The tensor to get the maximum values of.
926    /// * `dim` - The dimension to get the maximum values along.
927    /// * `k` - number of maximum elements.
928    ///
929    /// # Returns
930    ///
931    /// The values of the maximum elements along the dimension.
932    fn int_topk(tensor: IntTensor<B>, dim: usize, k: usize) -> IntTensor<B>;
933
    /// Gets the indices of the minimum elements along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the minimum indices of.
    /// * `dim` - The dimension to get the minimum indices along.
    ///
    /// # Returns
    ///
    /// An integer tensor holding the indices of the minimum elements along `dim`.
    fn int_argmin(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
945
946    /// Gets the maximum element in the tensor.
947    ///
948    /// # Arguments
949    ///
950    /// * `tensor` - The tensor to get the maximum element of.
951    ///
952    /// # Returns
953    ///
954    /// The maximum element in the tensor.
955    fn int_max(tensor: IntTensor<B>) -> IntTensor<B> {
956        let shape = tensor.shape();
957        let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()]));
958
959        B::int_max_dim(tensor, 0)
960    }
961
962    /// Gets the maximum element in the tensor along a dimension.
963    ///
964    /// # Arguments
965    ///
966    /// * `tensor` - The tensor to get the maximum element of.
967    /// * `dim` - The dimension to get the maximum element along.
968    ///
969    /// # Returns
970    ///
971    /// The maximum element in the tensor along the dimension.
972    fn int_max_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
973        let index = B::int_argmax(tensor.clone(), dim);
974        B::int_gather(dim, tensor, index)
975    }
976
977    /// Gets the maximum elements and corresponding indices along a dimension.
978    ///
979    /// # Arguments
980    ///
981    /// * `tensor` - The tensor to get the maximum elements and indices of.
982    /// * `dim` - The dimension to get the maximum elements and indices along.
983    ///
984    /// # Returns
985    ///
986    /// The maximum elements and corresponding indices along the dimension.
987    fn int_max_dim_with_indices(tensor: IntTensor<B>, dim: usize) -> (IntTensor<B>, IntTensor<B>) {
988        let index = B::int_argmax(tensor.clone(), dim);
989        let values = B::int_gather(dim, tensor, index.clone());
990
991        (values, index)
992    }
993
994    /// Gets the maximum absolute element in the tensor.
995    ///
996    /// # Arguments
997    ///
998    /// * `tensor` - The tensor to get the maximum element of.
999    ///
1000    /// # Returns
1001    ///
1002    /// The maximum element in the tensor.
1003    fn int_max_abs(tensor: IntTensor<B>) -> IntTensor<B> {
1004        let shape = tensor.shape();
1005        let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()]));
1006
1007        B::int_max_abs_dim(tensor, 0)
1008    }
1009
1010    /// Gets the maximum absolute element in the tensor along a dimension.
1011    ///
1012    /// # Arguments
1013    ///
1014    /// * `tensor` - The tensor to get the maximum element of.
1015    /// * `dim` - The dimension to get the maximum element along.
1016    ///
1017    /// # Returns
1018    ///
1019    /// The maximum element in the tensor along the dimension.
1020    fn int_max_abs_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
1021        B::int_max_dim(B::int_abs(tensor), dim)
1022    }
1023
1024    /// Gets the minimum element in the tensor.
1025    ///
1026    /// # Arguments
1027    ///
1028    /// * `tensor` - The tensor to get the minimum element of.
1029    ///
1030    /// # Returns
1031    ///
1032    /// The minimum element in the tensor.
1033    fn int_min(tensor: IntTensor<B>) -> IntTensor<B> {
1034        let shape = tensor.shape();
1035        let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()]));
1036
1037        B::int_min_dim(tensor, 0)
1038    }
1039
1040    /// Gets the minimum elements in the tensor along a dimension.
1041    ///
1042    /// # Arguments
1043    ///
1044    /// * `tensor` - The tensor to get the minimum element of.
1045    /// * `dim` - The dimension to get the minimum element along.
1046    ///
1047    /// # Returns
1048    ///
1049    /// The minimum element in the tensor along the dimension.
1050    fn int_min_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
1051        let index = B::int_argmin(tensor.clone(), dim);
1052        B::int_gather(dim, tensor, index)
1053    }
1054
1055    /// Gets the minimum elements and corresponding indices along a dimension.
1056    ///
1057    /// # Arguments
1058    ///
1059    /// * `tensor` - The tensor to get the minimum elements and indices of.
1060    /// * `dim` - The dimension to get the minimum elements and indices along.
1061    ///
1062    /// # Returns
1063    ///
1064    /// The minimum elements and corresponding indices along the dimension.
1065    fn int_min_dim_with_indices(tensor: IntTensor<B>, dim: usize) -> (IntTensor<B>, IntTensor<B>) {
1066        let indices = B::int_argmin(tensor.clone(), dim);
1067        let values = B::int_gather(dim, tensor, indices.clone());
1068
1069        (values, indices)
1070    }
1071
    /// Returns a new tensor with absolute values.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the absolute value of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor`, containing the absolute values of its elements.
    fn int_abs(tensor: IntTensor<B>) -> IntTensor<B>;
1082
1083    /// Transposes an int tensor.
1084    ///
1085    /// # Arguments
1086    ///
1087    /// * `tensor` - The tensor to transpose.
1088    ///
1089    /// # Returns
1090    ///
1091    /// The transposed tensor.
1092    fn int_transpose(tensor: IntTensor<B>) -> IntTensor<B> {
1093        let ndims = tensor.shape().num_dims();
1094        Self::int_swap_dims(tensor, ndims - 2, ndims - 1)
1095    }
1096
    /// Swaps two dimensions of an int tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to swap the dimensions of.
    /// * `dim1` - The first dimension to swap.
    /// * `dim2` - The second dimension to swap.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions `dim1` and `dim2` swapped.
    fn int_swap_dims(tensor: IntTensor<B>, dim1: usize, dim2: usize) -> IntTensor<B>;
1109
    /// Permutes the dimensions of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to permute the dimensions of.
    /// * `axes` - The new order of the dimensions.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions permuted.
    fn int_permute(tensor: IntTensor<B>, axes: &[usize]) -> IntTensor<B>;
1120
    /// Reverses the order of elements in a tensor along the given axes.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to reverse.
    /// * `axes` - The axes to reverse.
    ///
    /// # Returns
    ///
    /// The tensor with the elements reversed.
    fn int_flip(tensor: IntTensor<B>, axes: &[usize]) -> IntTensor<B>;
1130
    /// Creates a new int tensor with random values.
    ///
    /// # Arguments
    ///
    /// * `shape` - The shape of the tensor.
    /// * `distribution` - The distribution to sample from.
    /// * `device` - The device to create the tensor on.
    /// * `dtype` - The target data type.
    ///
    /// # Returns
    ///
    /// The tensor with the given shape and random values.
    fn int_random(
        shape: Shape,
        distribution: Distribution,
        device: &Device<B>,
        dtype: IntDType,
    ) -> IntTensor<B>;
1148
1149    /// Creates a new tensor with values from the given range with the given step size.
1150    ///
1151    /// # Arguments
1152    ///
1153    /// * `range` - The range of values.
1154    /// * `step` - The step size.
1155    /// * `device` - The device to create the tensor on.
1156    /// * `dtype` - The target data type.
1157    ///
1158    /// # Returns
1159    ///
1160    /// The tensor with the given values.
1161    fn int_arange_step(
1162        range: Range<i64>,
1163        step: usize,
1164        device: &Device<B>,
1165        dtype: IntDType,
1166    ) -> IntTensor<B> {
1167        let value = range
1168            .step_by(step)
1169            .map(|i| i.elem())
1170            .collect::<Vec<IntElem<B>>>();
1171        let shape = Shape::new([value.len()]);
1172        let data = TensorData::new(value, shape).convert_dtype(dtype.into());
1173        B::int_from_data(data, device)
1174    }
1175
1176    /// Creates a new tensor with values from the given range.
1177    ///
1178    /// # Arguments
1179    ///
1180    /// * `range` - The range of values.
1181    /// * `device` - The device to create the tensor on.
1182    ///
1183    /// # Returns
1184    ///
1185    /// The tensor with the given values.
1186    ///
1187    /// # Remarks
1188    ///
1189    /// Uses `arange_step` with a step size of 1 under the hood.
1190    fn int_arange(range: Range<i64>, device: &Device<B>, dtype: IntDType) -> IntTensor<B> {
1191        Self::int_arange_step(range, 1, device, dtype)
1192    }
1193
1194    /// Tests if any element in the int `tensor` evaluates to True.
1195    ///
1196    /// # Arguments
1197    ///
1198    /// * `tensor` - The tensor to test.
1199    ///
1200    /// # Returns
1201    ///
1202    /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
1203    fn int_any(tensor: IntTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
1204        let int_dtype = tensor.dtype();
1205        let bool_tensor = B::int_equal_elem(tensor, 0.into(), out_dtype);
1206        let bool_tensor = B::bool_not(bool_tensor);
1207        let sum = B::int_sum(B::bool_into_int(bool_tensor, int_dtype.into()));
1208        B::int_greater_elem(sum, 0.into(), out_dtype)
1209    }
1210
1211    /// Tests if any element in the int `tensor` evaluates to True along a given dimension `dim`.
1212    ///
1213    /// # Arguments
1214    ///
1215    /// * `tensor` - The tensor to test.
1216    /// * `dim` - The axis along which to test.
1217    ///
1218    /// # Returns
1219    ///
1220    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1221    /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input
1222    /// evaluates to True, False otherwise.
1223    fn int_any_dim(tensor: IntTensor<B>, dim: usize, out_dtype: BoolDType) -> BoolTensor<B> {
1224        let int_dtype = tensor.dtype();
1225        let bool_tensor = B::int_equal_elem(tensor, 0.into(), out_dtype);
1226        let bool_tensor = B::bool_not(bool_tensor);
1227        let sum = B::int_sum_dim(B::bool_into_int(bool_tensor, int_dtype.into()), dim);
1228        B::int_greater_elem(sum, 0.into(), out_dtype)
1229    }
1230
1231    /// Tests if all elements in the int `tensor` evaluate to True.
1232    ///
1233    /// # Arguments
1234    ///
1235    /// * `tensor` - The tensor to test.
1236    /// * `out_dtype` - The output tensor dtype.
1237    ///
1238    /// # Returns
1239    ///
1240    /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1241    /// evaluate to True, False otherwise.
1242    fn int_all(tensor: IntTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
1243        let int_dtype = tensor.dtype();
1244        let num_elems = tensor.shape().num_elements() as i64;
1245        let bool_tensor = B::int_equal_elem(tensor, 0.into(), out_dtype);
1246        let bool_tensor = B::bool_not(bool_tensor);
1247        let sum = B::int_sum(B::bool_into_int(bool_tensor, int_dtype.into()));
1248        B::int_equal_elem(sum, num_elems.into(), out_dtype)
1249    }
1250
1251    /// Tests if all elements in the int `tensor` evaluate to True along a given dimension `dim`.
1252    ///
1253    /// # Arguments
1254    ///
1255    /// * `tensor` - The tensor to test.
1256    /// * `dim` - The axis along which to test.
1257    /// * `out_dtype` - The output tensor dtype.
1258    ///
1259    /// # Returns
1260    ///
1261    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1262    /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1263    /// evaluates to True, False otherwise.
1264    fn int_all_dim(tensor: IntTensor<B>, dim: usize, out_dtype: BoolDType) -> BoolTensor<B> {
1265        let int_dtype = tensor.dtype();
1266        let num_elems = tensor.shape()[dim] as i64;
1267        let bool_tensor = B::int_equal_elem(tensor, 0.into(), out_dtype);
1268        let bool_tensor = B::bool_not(bool_tensor);
1269        let sum = B::int_sum_dim(B::bool_into_int(bool_tensor, int_dtype.into()), dim);
1270        B::int_equal_elem(sum, num_elems.into(), out_dtype)
1271    }
1272
1273    /// Returns the signs of the int `tensor`.
1274    ///
1275    /// # Arguments
1276    ///
1277    /// * `tensor` - The tensor to extract the signs from.
1278    ///
1279    /// # Returns
1280    ///
1281    /// A tensor with the same shape as `tensor` containing the signs of the elements of `tensor`.
1282    fn int_sign(tensor: IntTensor<B>) -> IntTensor<B> {
1283        let dtype = tensor.dtype();
1284        let device = B::int_device(&tensor);
1285        let bool_dtype = get_device_settings::<B>(&B::int_device(&tensor)).bool_dtype;
1286        let zeros = B::int_zeros(tensor.shape(), &device, dtype.into());
1287        let less_than_zero = B::int_lower_elem(tensor.clone(), 0.into(), bool_dtype);
1288        let greater_than_zero = B::int_greater_elem(tensor, 0.into(), bool_dtype);
1289
1290        let mut result = B::int_mask_fill(zeros, less_than_zero, (-1).into());
1291        result = B::int_mask_fill(result, greater_than_zero, 1.into());
1292        result
1293    }
1294
    /// Broadcasts the int `tensor` to the given `shape`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to broadcast.
    /// * `shape` - The target shape.
    ///
    /// # Returns
    ///
    /// The tensor broadcast to `shape`.
    fn int_expand(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B>;
1297
1298    /// Sort the elements of the input `tensor` by value along a given dimension.
1299    ///
1300    /// This sort is unstable (i.e., may reorder equal elements).
1301    ///
1302    /// # Arguments
1303    ///
1304    /// * `tensor` - The input tensor.
1305    /// * `dim` - The axis along which to sort.
1306    /// * `descending` - The sorting order.
1307    ///
1308    /// # Returns
1309    ///
1310    /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
1311    fn int_sort(tensor: IntTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
1312        sort::<B, Int>(tensor, dim, descending)
1313    }
1314
1315    /// Sort the elements of the input `tensor` by value along a given dimension.
1316    ///
1317    /// This sort is unstable (i.e., may reorder equal elements).
1318    ///
1319    /// # Arguments
1320    ///
1321    /// * `tensor` - The input tensor.
1322    /// * `dim` - The axis along which to sort.
1323    ///
1324    /// # Returns
1325    ///
1326    /// A tensor with the same shape as the input tensor and corresponding indices, where
1327    /// the elements are sorted by value and the indices map back to the original input tensor.
1328    fn int_sort_with_indices(
1329        tensor: IntTensor<B>,
1330        dim: usize,
1331        descending: bool,
1332    ) -> (IntTensor<B>, IntTensor<B>) {
1333        let dtype = tensor.dtype();
1334        sort_with_indices::<B, Int>(tensor, dim, descending, dtype.into())
1335    }
1336
1337    /// Returns the indices that sort the elements of the input `tensor` by value
1338    /// along a given dimension.
1339    ///
1340    /// This sort is unstable (i.e., may reorder equal elements).
1341    ///
1342    /// # Arguments
1343    ///
1344    /// * `tensor` - The input tensor.
1345    /// * `dim` - The axis along which to sort.
1346    /// * `descending` - The sorting order.
1347    ///
1348    /// # Returns
1349    ///
1350    /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.
1351    fn int_argsort(tensor: IntTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
1352        let dtype = tensor.dtype();
1353        argsort::<B, Int>(tensor, dim, descending, dtype.into())
1354    }
1355
    /// Bitwise AND operation for Int Tensors.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// The tensor holding the bitwise AND of `lhs` and `rhs`.
    fn bitwise_and(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
1358
    /// Bitwise AND operation for Int Tensors with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar.
    ///
    /// # Returns
    ///
    /// The tensor holding the bitwise AND of `lhs` and the scalar `rhs`.
    fn bitwise_and_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
1361
    /// Bitwise OR operation for Int Tensors.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// The tensor holding the bitwise OR of `lhs` and `rhs`.
    fn bitwise_or(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
1364
    /// Bitwise OR operation for Int Tensors with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar.
    ///
    /// # Returns
    ///
    /// The tensor holding the bitwise OR of `lhs` and the scalar `rhs`.
    fn bitwise_or_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
1367
    /// Bitwise XOR operation for Int Tensors.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// The tensor holding the bitwise XOR of `lhs` and `rhs`.
    fn bitwise_xor(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
1370
    /// Bitwise XOR operation for Int Tensors with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar.
    ///
    /// # Returns
    ///
    /// The tensor holding the bitwise XOR of `lhs` and the scalar `rhs`.
    fn bitwise_xor_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
1373
    /// Bitwise NOT operation for Int Tensors.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to negate bitwise.
    ///
    /// # Returns
    ///
    /// The tensor holding the bitwise NOT of `tensor`.
    fn bitwise_not(tensor: IntTensor<B>) -> IntTensor<B>;
1376
    /// Bitwise left shift operation for Int Tensors.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor (presumably the values being shifted - confirm).
    /// * `rhs` - The right-hand side tensor (presumably the shift amounts - confirm).
    ///
    /// # Returns
    ///
    /// The tensor holding the result of the left shift.
    fn bitwise_left_shift(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
1379
    /// Bitwise left shift operation for Int Tensors with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor (presumably the values being shifted - confirm).
    /// * `rhs` - The right-hand side scalar (presumably the shift amount - confirm).
    ///
    /// # Returns
    ///
    /// The tensor holding the result of the left shift.
    fn bitwise_left_shift_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
1382
    /// Bitwise right shift operation for Int Tensors.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor (presumably the values being shifted - confirm).
    /// * `rhs` - The right-hand side tensor (presumably the shift amounts - confirm).
    ///
    /// # Returns
    ///
    /// The tensor holding the result of the right shift.
    fn bitwise_right_shift(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
1385
    /// Bitwise right shift operation for Int Tensors with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor (presumably the values being shifted - confirm).
    /// * `rhs` - The right-hand side scalar (presumably the shift amount - confirm).
    ///
    /// # Returns
    ///
    /// The tensor holding the result of the right shift.
    fn bitwise_right_shift_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
1388
    /// Converts a tensor to another integer data type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to convert.
    /// * `dtype` - The target integer data type.
    ///
    /// # Returns
    ///
    /// A tensor with the same values as `tensor` but in the target integer data type.
    fn int_cast(tensor: IntTensor<B>, dtype: IntDType) -> IntTensor<B>;
1400
    /// Unfold windows along a dimension.
    ///
    /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;
    /// where windows are advanced by `step` at each index.
    ///
    /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.
    ///
    /// NOTE(review): confirm this formula against the backend implementations; unfold is
    /// commonly defined with `(shape[dim] - size) / step + 1` complete windows.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``.
    /// * `dim` - The selected dimension.
    /// * `size` - The size of each unfolded window.
    /// * `step` - The step between each window.
    ///
    /// # Returns
    ///
    /// A tensor view with shape ``[pre=..., windows, size, post=...]``.
    fn int_unfold(tensor: IntTensor<B>, dim: usize, size: usize, step: usize) -> IntTensor<B>;
1419}