Skip to main content

burn_backend/backend/ops/
int_tensor.rs

1use super::cat::cat_with_slice_assign;
2use super::repeat_dim::repeat_with_slice_assign;
3use super::sort::{argsort, sort, sort_with_indices};
4use crate::tensor::{BoolTensor, Device, FloatTensor, Int, IntElem, IntTensor};
5use crate::{Backend, Distribution, TensorData, TensorMetadata, element::ElementConversion};
6use crate::{ExecutionError, Scalar, get_device_settings};
7use alloc::vec::Vec;
8use burn_std::{BoolDType, FloatDType, IntDType, Shape, Slice};
9use core::ops::Range;
10
11/// Int Tensor API for basic and numeric operations, see
12#[cfg_attr(doc, doc = crate::doc_tensor!())]
13#[cfg_attr(not(doc), doc = "`Tensor`")]
14/// for documentation on each function.
15pub trait IntTensorOps<B: Backend> {
    /// Creates a new int tensor with the given shape.
    ///
    /// NOTE(review): the name suggests the element values are left unspecified; confirm
    /// against backend implementations before relying on any particular contents.
    ///
    /// # Arguments
    ///
    /// * `shape` - The shape of the tensor.
    /// * `device` - The device to create the tensor on.
    /// * `dtype` - The target integer data type.
    ///
    /// # Returns
    ///
    /// The integer tensor with the given shape.
    fn int_empty(shape: Shape, device: &Device<B>, dtype: IntDType) -> IntTensor<B>;
28
    /// Converts the tensor to a data structure.
    ///
    /// The data is delivered through the returned future rather than directly.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// A future resolving to the tensor's data.
    ///
    /// # Errors
    ///
    /// The future resolves to an [`ExecutionError`] when the data cannot be produced.
    fn int_into_data(
        tensor: IntTensor<B>,
    ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;
41
    /// Creates a tensor from the data structure.
    ///
    /// NOTE(review): no dtype argument is taken, so the element type presumably comes
    /// from `data` itself; confirm against backend implementations.
    ///
    /// # Arguments
    ///
    /// * `data` - The data structure.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// The tensor with the data.
    fn int_from_data(data: TensorData, device: &Device<B>) -> IntTensor<B>;
53
    /// Gets the device of the tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor (only borrowed).
    ///
    /// # Returns
    ///
    /// The device the tensor lives on.
    fn int_device(tensor: &IntTensor<B>) -> Device<B>;
64
    /// Moves the tensor to the given device.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to move.
    /// * `device` - The destination device.
    ///
    /// # Returns
    ///
    /// The tensor located on `device`.
    fn int_to_device(tensor: IntTensor<B>, device: &Device<B>) -> IntTensor<B>;
67
    /// Reshapes the tensor.
    ///
    /// NOTE(review): `shape` presumably must describe the same number of elements as
    /// the input tensor; confirm.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `shape` - The new shape.
    ///
    /// # Returns
    ///
    /// The tensor with the new shape.
    fn int_reshape(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B>;
79
    /// Gets the elements selected by the given slices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `slices` - The slices specifying ranges and steps for each dimension.
    ///
    /// # Returns
    ///
    /// A tensor containing the selected elements.
    ///
    /// # Note
    ///
    /// Empty slices (where start >= end) are handled at the high-level tensor API and will not
    /// be passed to this method. Backend implementations do not need to handle empty slices.
    fn int_slice(tensor: IntTensor<B>, slices: &[Slice]) -> IntTensor<B>;
96
    /// Sets the values in the tensor for the given slices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `slices` - The slices selecting the region to overwrite in each dimension.
    /// * `value` - The tensor whose values are written into that region.
    ///
    /// # Returns
    ///
    /// The tensor with the values set for the given slices.
    ///
    /// # Note
    ///
    /// Empty slice assignments (where any slice range produces 0 elements) are handled at the
    /// high-level tensor API and will not be passed to this method. Backend implementations do
    /// not need to handle empty slice assignments.
    fn int_slice_assign(
        tensor: IntTensor<B>,
        slices: &[Slice],
        value: IntTensor<B>,
    ) -> IntTensor<B>;
118
    /// Converts int tensor to float tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The int tensor to convert.
    /// * `out_dtype` - The output tensor dtype.
    ///
    /// # Returns
    ///
    /// A float tensor of dtype `out_dtype` holding the values of the int tensor
    /// converted to floating point.
    fn int_into_float(tensor: IntTensor<B>, out_dtype: FloatDType) -> FloatTensor<B>;
130
    /// Fills the tensor with values from the value tensor if the mask is true at the given
    /// indices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `mask` - The mask.
    /// * `value` - The value tensor.
    ///
    /// # Returns
    ///
    /// A tensor taking each element from `value` where `mask` is true and from `tensor`
    /// otherwise.
    fn int_mask_where(
        tensor: IntTensor<B>,
        mask: BoolTensor<B>,
        value: IntTensor<B>,
    ) -> IntTensor<B>;
148
    /// Fills the tensor with the given value if the mask is true at the given indices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `mask` - The mask.
    /// * `value` - The scalar written wherever `mask` is true.
    ///
    /// # Returns
    ///
    /// The tensor with the masked positions replaced by `value`.
    fn int_mask_fill(tensor: IntTensor<B>, mask: BoolTensor<B>, value: Scalar) -> IntTensor<B>;
161
    /// Gather elements from the tensor at the given indices.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension to gather from.
    /// * `tensor` - The tensor.
    /// * `indices` - The indices.
    ///
    /// # Returns
    ///
    /// The gathered elements.
    /// NOTE(review): the output presumably has the shape of `indices`; confirm.
    fn int_gather(dim: usize, tensor: IntTensor<B>, indices: IntTensor<B>) -> IntTensor<B>;
170
    /// Scatter a given value to the tensor at the given indices using sum reduction.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension to scatter to.
    /// * `tensor` - The tensor.
    /// * `indices` - The indices.
    /// * `value` - The values added at the indexed positions.
    ///
    /// # Returns
    ///
    /// The tensor with the values scattered.
    fn int_scatter_add(
        dim: usize,
        tensor: IntTensor<B>,
        indices: IntTensor<B>,
        value: IntTensor<B>,
    ) -> IntTensor<B>;
189
    /// Multi-dimensional scatter for int tensors.
    ///
    /// # Arguments
    ///
    /// * `_data` - The tensor to scatter into.
    /// * `_indices` - The indices selecting the positions to update.
    /// * `_values` - The values to scatter.
    /// * `_reduction` - The `IndexingUpdateOp` applied at each updated position.
    ///
    /// # Panics
    ///
    /// The default implementation always panics via `unimplemented!`; backends that
    /// support this operation must override it.
    fn int_scatter_nd(
        _data: IntTensor<B>,
        _indices: IntTensor<B>,
        _values: IntTensor<B>,
        _reduction: crate::tensor::IndexingUpdateOp,
    ) -> IntTensor<B> {
        unimplemented!("int_scatter_nd is not implemented for this backend")
    }
199
    /// Multi-dimensional gather for int tensors.
    ///
    /// # Arguments
    ///
    /// * `_data` - The tensor to gather from.
    /// * `_indices` - The indices selecting the elements to read.
    ///
    /// # Panics
    ///
    /// The default implementation always panics via `unimplemented!`; backends that
    /// support this operation must override it.
    fn int_gather_nd(_data: IntTensor<B>, _indices: IntTensor<B>) -> IntTensor<B> {
        unimplemented!("int_gather_nd is not implemented for this backend")
    }
204
    /// Select tensor elements along the given dimension corresponding to the given indices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `dim` - The dimension to select from.
    /// * `indices` - The indices.
    ///
    /// # Returns
    ///
    /// The tensor with the selected elements.
    fn int_select(tensor: IntTensor<B>, dim: usize, indices: IntTensor<B>) -> IntTensor<B>;
217
    /// Assign the selected elements along the given dimension corresponding to the given indices
    /// to the given value using sum reduction.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `dim` - The dimension to select from.
    /// * `indices` - The indices.
    /// * `value` - The values summed into the selected positions.
    ///
    /// # Returns
    ///
    /// The tensor with the selected elements assigned to the given value.
    fn int_select_add(
        tensor: IntTensor<B>,
        dim: usize,
        indices: IntTensor<B>,
        value: IntTensor<B>,
    ) -> IntTensor<B>;
237
    /// Repeats the tensor along the given dimension the given number of times.
    ///
    /// The default implementation is built on slice assignment (`repeat_with_slice_assign`);
    /// backends may override it with a native implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `dim` - The dimension to repeat.
    /// * `times` - The number of times to repeat.
    ///
    /// # Returns
    ///
    /// The tensor with the given dimension repeated the given number of times.
    fn int_repeat_dim(tensor: IntTensor<B>, dim: usize, times: usize) -> IntTensor<B> {
        repeat_with_slice_assign::<B, Int>(tensor, dim, times)
    }
252
    /// Concatenates the given tensors along the given dimension.
    ///
    /// The default implementation is built on slice assignment (`cat_with_slice_assign`);
    /// backends may override it with a native implementation.
    ///
    /// # Arguments
    ///
    /// * `tensors` - The tensors.
    /// * `dim` - The dimension to concatenate along.
    ///
    /// # Returns
    ///
    /// The concatenated tensor.
    ///
    /// # Note
    ///
    /// Empty tensors (where the concatenation dimension has size 0) are filtered out at the
    /// high-level tensor API and will not be passed to this method. Backend implementations do
    /// not need to handle empty tensors.
    fn int_cat(tensors: Vec<IntTensor<B>>, dim: usize) -> IntTensor<B> {
        cat_with_slice_assign::<B, Int>(tensors, dim)
    }
272
    /// Element-wise equality comparison.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    /// * `out_dtype` - The boolean dtype of the output tensor.
    ///
    /// # Returns
    ///
    /// The boolean tensor with the result of the comparison.
    fn int_equal(lhs: IntTensor<B>, rhs: IntTensor<B>, out_dtype: BoolDType) -> BoolTensor<B>;
285
286    /// Element-wise non-equality comparison.
287    ///
288    /// # Arguments
289    ///
290    /// * `lhs` - The left-hand side tensor.
291    /// * `rhs` - The right-hand side tensor.
292    /// * `out_dtype` - The output tensor dtype.
293    ///
294    /// # Returns
295    ///
296    /// The boolean tensor with the result of the comparison.
297    fn int_not_equal(lhs: IntTensor<B>, rhs: IntTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
298        let equal_tensor = B::int_equal(lhs, rhs, out_dtype);
299        B::bool_not(equal_tensor)
300    }
301
    /// Element-wise equality comparison with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar.
    /// * `out_dtype` - The boolean dtype of the output tensor.
    ///
    /// # Returns
    ///
    /// The boolean tensor with the result of the comparison.
    fn int_equal_elem(lhs: IntTensor<B>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<B>;
314
315    /// Element-wise non-equality comparison with a scalar.
316    ///
317    /// # Arguments
318    ///
319    /// * `lhs` - The left-hand side tensor.
320    /// * `rhs` - The right-hand side scalar.
321    /// * `out_dtype` - The output tensor dtype.
322    ///
323    /// # Returns
324    ///
325    /// The boolean tensor with the result of the comparison.
326    fn int_not_equal_elem(lhs: IntTensor<B>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<B> {
327        let equal_tensor = B::int_equal_elem(lhs, rhs, out_dtype);
328        B::bool_not(equal_tensor)
329    }
330
    /// Element-wise greater than comparison.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    /// * `out_dtype` - The boolean dtype of the output tensor.
    ///
    /// # Returns
    ///
    /// The boolean tensor with the result of the comparison.
    fn int_greater(lhs: IntTensor<B>, rhs: IntTensor<B>, out_dtype: BoolDType) -> BoolTensor<B>;
343
    /// Element-wise greater than comparison with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar.
    /// * `out_dtype` - The boolean dtype of the output tensor.
    ///
    /// # Returns
    ///
    /// The boolean tensor with the result of the comparison.
    fn int_greater_elem(lhs: IntTensor<B>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<B>;
356
    /// Element-wise greater than or equal comparison.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    /// * `out_dtype` - The boolean dtype of the output tensor.
    ///
    /// # Returns
    ///
    /// The boolean tensor with the result of the comparison.
    fn int_greater_equal(
        lhs: IntTensor<B>,
        rhs: IntTensor<B>,
        out_dtype: BoolDType,
    ) -> BoolTensor<B>;
373
    /// Element-wise greater than or equal comparison with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar.
    /// * `out_dtype` - The boolean dtype of the output tensor.
    ///
    /// # Returns
    ///
    /// The boolean tensor with the result of the comparison.
    fn int_greater_equal_elem(
        lhs: IntTensor<B>,
        rhs: Scalar,
        out_dtype: BoolDType,
    ) -> BoolTensor<B>;
390
    /// Element-wise less than comparison.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    /// * `out_dtype` - The boolean dtype of the output tensor.
    ///
    /// # Returns
    ///
    /// The boolean tensor with the result of the comparison.
    fn int_lower(lhs: IntTensor<B>, rhs: IntTensor<B>, out_dtype: BoolDType) -> BoolTensor<B>;
403
    /// Element-wise less than comparison with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar.
    /// * `out_dtype` - The boolean dtype of the output tensor.
    ///
    /// # Returns
    ///
    /// The boolean tensor with the result of the comparison.
    fn int_lower_elem(lhs: IntTensor<B>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<B>;
416
    /// Element-wise less than or equal comparison.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    /// * `out_dtype` - The boolean dtype of the output tensor.
    ///
    /// # Returns
    ///
    /// The boolean tensor with the result of the comparison.
    fn int_lower_equal(lhs: IntTensor<B>, rhs: IntTensor<B>, out_dtype: BoolDType)
    -> BoolTensor<B>;
430
    /// Element-wise less than or equal comparison with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar.
    /// * `out_dtype` - The boolean dtype of the output tensor.
    ///
    /// # Returns
    ///
    /// The boolean tensor with the result of the comparison.
    fn int_lower_equal_elem(lhs: IntTensor<B>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<B>;
443
444    // ====  NUMERIC ==== //
445
    /// Element-wise addition.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// The result of the addition.
    fn int_add(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
457
    /// Element-wise addition with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar, added to every element of `lhs`.
    ///
    /// # Returns
    ///
    /// The result of the addition.
    fn int_add_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
469
470    /// Element-wise power with a IntTensor.
471    ///
472    /// # Arguments
473    ///
474    /// * `lhs` - The left-hand side IntTensor.
475    /// * `rhs` - The right-hand side IntTensor.
476    ///
477    /// # Returns
478    ///
479    /// The elements of `lhs` raised to the power of the elements of `rhs`.
480    fn int_powi(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
481        let dtype = lhs.dtype();
482        let float_dtype = get_device_settings::<B>(&B::int_device(&lhs)).float_dtype;
483        B::float_into_int(
484            B::float_powi(B::int_into_float(lhs, float_dtype), rhs),
485            dtype.into(),
486        )
487    }
488
489    /// Element-wise power with a scalar.
490    ///
491    /// # Backend Implementors Note
492    ///
493    /// A number of common exponent cases can be implemented with operations
494    /// which are much cheaper than generic exponentiation.
495    ///
496    /// This (`Backend` impl overridable) operation handles generic optimizations
497    /// for several common integer exponent cases; and then dispatches to
498    /// the (`Backend` impl overridable) [`Self::int_powi_scalar_impl`]
499    /// operation to handle the generic case.
500    ///
501    /// # Arguments
502    ///
503    /// * `lhs` - The left-hand side tensor.
504    /// * `rhs` - The right-hand side scalar.
505    ///
506    /// # Returns
507    ///
508    /// The elements of `lhs` raised to the value of `rhs`.
509    fn int_powi_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B> {
510        let exp = rhs.elem::<i32>();
511        match exp {
512            0 => Self::int_ones(lhs.shape(), &B::int_device(&lhs), lhs.dtype().into()),
513            1 => lhs,
514            2 => Self::int_mul(lhs.clone(), lhs),
515            _ => Self::int_powi_scalar_impl(lhs, rhs),
516        }
517    }
518
519    /// Element-wise power with a scalar.
520    ///
521    /// # Backend Implementors Note
522    ///
523    /// This is the generic implementation of integer exponentiation
524    /// called by [`Self::int_powi_scalar`] in the fallback case.
525    ///
526    /// By default, this performs a relatively expensive conversion to float,
527    /// exponentiation in float, and conversion back to int.
528    /// This reduces the minimal operation set for `Backend`s,
529    /// at the cost of performance.
530    ///
531    /// This is a good target for specialized optimizations in `Backend` implementations.
532    ///
533    /// As a general rule, this should not be called directly.
534    ///
535    /// # Arguments
536    ///
537    /// * `lhs` - The left-hand side tensor.
538    /// * `rhs` - The right-hand side scalar.
539    ///
540    /// # Returns
541    ///
542    /// The elements of `lhs` raised to the value of `rhs`.
543    fn int_powi_scalar_impl(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B> {
544        let dtype = lhs.dtype();
545        let float_dtype = get_device_settings::<B>(&B::int_device(&lhs)).float_dtype;
546        B::float_into_int(
547            B::float_powi_scalar_impl(B::int_into_float(lhs, float_dtype), rhs),
548            dtype.into(),
549        )
550    }
551
552    /// Clamps a tensor under a minimum value.
553    ///
554    /// # Arguments
555    ///
556    /// * `tensor` - The tensor to clamp.
557    /// * `min` - The minimum value.
558    ///
559    /// # Returns
560    ///
561    /// The clamped tensor.
562    fn int_clamp_min(tensor: IntTensor<B>, min: Scalar) -> IntTensor<B> {
563        let dtype = get_device_settings::<B>(&B::int_device(&tensor)).bool_dtype;
564        let mask = Self::int_lower_elem(tensor.clone(), min, dtype);
565        Self::int_mask_fill(tensor, mask, min)
566    }
567
568    /// Clamps a tensor over a maximum value.
569    ///
570    /// # Arguments
571    ///
572    /// * `tensor` - The tensor to clamp.
573    /// * `max` - The maximum value.
574    ///
575    /// # Returns
576    ///
577    /// The clamped tensor.
578    fn int_clamp_max(tensor: IntTensor<B>, max: Scalar) -> IntTensor<B> {
579        let dtype = get_device_settings::<B>(&B::int_device(&tensor)).bool_dtype;
580        let mask = Self::int_greater_elem(tensor.clone(), max, dtype);
581        Self::int_mask_fill(tensor, mask, max)
582    }
583
584    /// Clamps a tensor between a minimum and maximum value.
585    ///
586    /// # Arguments
587    ///
588    /// * `tensor` - The tensor to clamp.
589    /// * `min` - The minimum value.
590    /// * `max` - The maximum value.
591    ///
592    /// # Returns
593    ///
594    /// The clamped tensor.
595    fn int_clamp(tensor: IntTensor<B>, min: Scalar, max: Scalar) -> IntTensor<B> {
596        Self::int_clamp_min(Self::int_clamp_max(tensor, max), min)
597    }
598
    /// Element-wise subtraction.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor (minuend).
    /// * `rhs` - The right-hand side tensor (subtrahend).
    ///
    /// # Returns
    ///
    /// The result of the subtraction.
    fn int_sub(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
610
    /// Element-wise subtraction with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar, subtracted from every element of `lhs`.
    ///
    /// # Returns
    ///
    /// The result of the subtraction.
    fn int_sub_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
622
    /// Element-wise multiplication.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// The result of the multiplication.
    fn int_mul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
634
    /// Element-wise multiplication with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar, multiplied into every element of `lhs`.
    ///
    /// # Returns
    ///
    /// The result of the multiplication.
    fn int_mul_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
646
    /// Element-wise division.
    ///
    /// NOTE(review): the rounding mode of integer division (and behavior on a zero
    /// divisor) is not fixed by this signature; confirm against the backends.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor (dividend).
    /// * `rhs` - The right-hand side tensor (divisor).
    ///
    /// # Returns
    ///
    /// The result of the division.
    fn int_div(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
658
    /// Element-wise division with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor (dividend).
    /// * `rhs` - The right-hand side scalar (divisor).
    ///
    /// # Returns
    ///
    /// The result of the division.
    fn int_div_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
670
    /// Element-wise remainder (modulus) of two tensors.
    ///
    /// NOTE(review): the sign convention for negative operands (truncated vs. floored
    /// remainder) is not fixed by this signature; confirm against the backends.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor (dividend).
    /// * `rhs` - The right-hand side tensor (divisor).
    ///
    /// # Returns
    ///
    /// The element-wise remainder of `lhs` divided by `rhs`.
    fn int_remainder(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
681
    /// Element-wise modulus with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor (dividend).
    /// * `rhs` - The right-hand side scalar (divisor).
    ///
    /// # Returns
    ///
    /// The result of applying the modulus of the scalar to the tensor.
    fn int_remainder_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
692
    /// Multiplies two tensors together using matrix multiplication.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// The result of multiplying the two tensors together using matrix multiplication.
    fn int_matmul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
704
705    /// Element-wise negation.
706    ///
707    /// # Arguments
708    ///
709    /// * `tensor` - The tensor to negate.
710    ///
711    /// # Returns
712    ///
713    /// The negated tensor.
714    fn int_neg(tensor: IntTensor<B>) -> IntTensor<B> {
715        Self::int_mul_scalar(tensor, (-1).into())
716    }
717
718    /// Creates a tensor of zeros.
719    ///
720    /// # Arguments
721    ///
722    /// * `shape` - The shape of the tensor.
723    /// * `device` - The device to create the tensor on.
724    /// * `dtype` - The target data type.
725    ///
726    /// # Returns
727    ///
728    /// The tensor of zeros.
729    fn int_zeros(shape: Shape, device: &Device<B>, dtype: IntDType) -> IntTensor<B> {
730        Self::int_from_data(TensorData::full_dtype(shape, 0, dtype.into()), device)
731    }
732
733    /// Creates a tensor of ones.
734    ///
735    /// # Arguments
736    ///
737    /// * `shape` - The shape of the tensor.
738    /// * `device` - The device to create the tensor on.
739    /// * `dtype` - The target data type.
740    ///
741    /// # Returns
742    ///
743    /// The tensor of ones.
744    fn int_ones(shape: Shape, device: &Device<B>, dtype: IntDType) -> IntTensor<B> {
745        Self::int_from_data(TensorData::full_dtype(shape, 1, dtype.into()), device)
746    }
747
748    /// Creates a tensor filled with given value.
749    ///
750    /// # Arguments
751    ///
752    /// * `shape` - The shape of the tensor.
753    /// * `fill_value` - The value with which to fill the tensor.
754    /// * `device` - The device to create the tensor on.
755    /// * `dtype` - The target data type.
756    ///
757    /// # Returns
758    ///
759    /// The tensor filled with given value
760    fn int_full(
761        shape: Shape,
762        fill_value: Scalar,
763        device: &Device<B>,
764        dtype: IntDType,
765    ) -> IntTensor<B> {
766        Self::int_from_data(
767            TensorData::full_dtype(shape, fill_value, dtype.into()),
768            device,
769        )
770    }
771
    /// Sums all elements in the tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to sum.
    ///
    /// # Returns
    ///
    /// The sum of all elements in the tensor.
    fn int_sum(tensor: IntTensor<B>) -> IntTensor<B>;
782
    /// Sums all elements in the tensor along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to sum.
    /// * `dim` - The dimension to sum along.
    ///
    /// # Returns
    ///
    /// The sum of all elements in the tensor along the dimension.
    fn int_sum_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
794
    /// Computes the product of all elements in the tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the product of.
    ///
    /// # Returns
    ///
    /// The product of all elements in the tensor.
    fn int_prod(tensor: IntTensor<B>) -> IntTensor<B>;
805
    /// Computes the product of all elements in the tensor along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the product of.
    /// * `dim` - The dimension to compute the product along.
    ///
    /// # Returns
    ///
    /// The product of all elements in the tensor along the dimension.
    fn int_prod_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
817
818    /// Computes the mean of all elements in the tensor.
819    ///
820    /// # Arguments
821    ///
822    /// * `tensor` - The tensor to compute the mean of.
823    ///
824    /// # Returns
825    ///
826    /// The mean of all elements in the tensor.
827    fn int_mean(tensor: IntTensor<B>) -> IntTensor<B> {
828        let num_elems = tensor.shape().num_elements() as i64;
829        B::int_div_scalar(B::int_sum(tensor), num_elems.into())
830    }
831
    /// Computes the mean of all elements in the tensor along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the mean of.
    /// * `dim` - The dimension to compute the mean along.
    ///
    /// # Returns
    ///
    /// The mean of all elements in the tensor along the dimension.
    fn int_mean_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
842
    /// Computes the cumulative sum of elements along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the cumulative sum of.
    /// * `dim` - The dimension along which to compute the cumulative sum.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape where each element is the cumulative sum
    /// of all elements up to and including that position along the dimension.
    fn int_cumsum(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
855
    /// Computes the cumulative product of elements along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the cumulative product of.
    /// * `dim` - The dimension along which to compute the cumulative product.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape where each element is the cumulative product
    /// of all elements up to and including that position along the dimension.
    fn int_cumprod(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
868
    /// Computes the cumulative minimum of elements along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the cumulative minimum of.
    /// * `dim` - The dimension along which to compute the cumulative minimum.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape where each element is the minimum
    /// of all elements up to and including that position along the dimension.
    fn int_cummin(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
881
    /// Computes the cumulative maximum of elements along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the cumulative maximum of.
    /// * `dim` - The dimension along which to compute the cumulative maximum.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape where each element is the maximum
    /// of all elements up to and including that position along the dimension.
    fn int_cummax(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
894
    /// Gets the indices of the maximum elements along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the maximum indices of.
    /// * `dim` - The dimension to get the maximum indices along.
    ///
    /// # Returns
    ///
    /// The indices of the maximum elements along the dimension.
    fn int_argmax(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
906
    /// Gets the indices of the k maximum elements along a dimension.
    ///
    /// Ties between equal values are broken in favor of the lower coordinate.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the maximum indices of.
    /// * `dim` - The dimension to get the maximum indices along.
    /// * `k` - The number of maximum elements.
    ///
    /// # Returns
    ///
    /// The indices of the maximum elements along the dimension.
    fn int_argtopk(tensor: IntTensor<B>, dim: usize, k: usize) -> IntTensor<B>;
921
922    /// Gets the values of the k maximum elements along a dimension.
923    /// # Arguments
924    ///
925    /// * `tensor` - The tensor to get the maximum values of.
926    /// * `dim` - The dimension to get the maximum values along.
927    /// * `k` - number of maximum elements.
928    ///
929    /// # Returns
930    ///
931    /// The values of the maximum elements along the dimension.
932    fn int_topk(tensor: IntTensor<B>, dim: usize, k: usize) -> IntTensor<B> {
933        let device = Self::int_device(&tensor);
934        let dtype = get_device_settings::<B>(&device).int_dtype;
935        let k_indices = Self::int_arange(0..k as i64, &device, dtype);
936        Self::int_select(Self::int_sort(tensor, dim, true), dim, k_indices)
937    }
938
    /// Gets the indices of the minimum elements along a dimension.
    ///
    /// The default `int_min_dim` and `int_min_dim_with_indices` implementations gather the
    /// values at these indices along `dim`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the minimum indices of.
    /// * `dim` - The dimension to get the minimum indices along.
    ///
    /// # Returns
    ///
    /// The indices of the minimum elements along the dimension.
    fn int_argmin(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
950
951    /// Gets the maximum element in the tensor.
952    ///
953    /// # Arguments
954    ///
955    /// * `tensor` - The tensor to get the maximum element of.
956    ///
957    /// # Returns
958    ///
959    /// The maximum element in the tensor.
960    fn int_max(tensor: IntTensor<B>) -> IntTensor<B> {
961        let shape = tensor.shape();
962        let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()]));
963
964        B::int_max_dim(tensor, 0)
965    }
966
967    /// Gets the maximum element in the tensor along a dimension.
968    ///
969    /// # Arguments
970    ///
971    /// * `tensor` - The tensor to get the maximum element of.
972    /// * `dim` - The dimension to get the maximum element along.
973    ///
974    /// # Returns
975    ///
976    /// The maximum element in the tensor along the dimension.
977    fn int_max_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
978        let index = B::int_argmax(tensor.clone(), dim);
979        B::int_gather(dim, tensor, index)
980    }
981
982    /// Gets the maximum elements and corresponding indices along a dimension.
983    ///
984    /// # Arguments
985    ///
986    /// * `tensor` - The tensor to get the maximum elements and indices of.
987    /// * `dim` - The dimension to get the maximum elements and indices along.
988    ///
989    /// # Returns
990    ///
991    /// The maximum elements and corresponding indices along the dimension.
992    fn int_max_dim_with_indices(tensor: IntTensor<B>, dim: usize) -> (IntTensor<B>, IntTensor<B>) {
993        let index = B::int_argmax(tensor.clone(), dim);
994        let values = B::int_gather(dim, tensor, index.clone());
995
996        (values, index)
997    }
998
999    /// Gets the maximum absolute element in the tensor.
1000    ///
1001    /// # Arguments
1002    ///
1003    /// * `tensor` - The tensor to get the maximum element of.
1004    ///
1005    /// # Returns
1006    ///
1007    /// The maximum element in the tensor.
1008    fn int_max_abs(tensor: IntTensor<B>) -> IntTensor<B> {
1009        let shape = tensor.shape();
1010        let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()]));
1011
1012        B::int_max_abs_dim(tensor, 0)
1013    }
1014
1015    /// Gets the maximum absolute element in the tensor along a dimension.
1016    ///
1017    /// # Arguments
1018    ///
1019    /// * `tensor` - The tensor to get the maximum element of.
1020    /// * `dim` - The dimension to get the maximum element along.
1021    ///
1022    /// # Returns
1023    ///
1024    /// The maximum element in the tensor along the dimension.
1025    fn int_max_abs_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
1026        B::int_max_dim(B::int_abs(tensor), dim)
1027    }
1028
1029    /// Gets the minimum element in the tensor.
1030    ///
1031    /// # Arguments
1032    ///
1033    /// * `tensor` - The tensor to get the minimum element of.
1034    ///
1035    /// # Returns
1036    ///
1037    /// The minimum element in the tensor.
1038    fn int_min(tensor: IntTensor<B>) -> IntTensor<B> {
1039        let shape = tensor.shape();
1040        let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()]));
1041
1042        B::int_min_dim(tensor, 0)
1043    }
1044
1045    /// Gets the minimum elements in the tensor along a dimension.
1046    ///
1047    /// # Arguments
1048    ///
1049    /// * `tensor` - The tensor to get the minimum element of.
1050    /// * `dim` - The dimension to get the minimum element along.
1051    ///
1052    /// # Returns
1053    ///
1054    /// The minimum element in the tensor along the dimension.
1055    fn int_min_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
1056        let index = B::int_argmin(tensor.clone(), dim);
1057        B::int_gather(dim, tensor, index)
1058    }
1059
1060    /// Gets the minimum elements and corresponding indices along a dimension.
1061    ///
1062    /// # Arguments
1063    ///
1064    /// * `tensor` - The tensor to get the minimum elements and indices of.
1065    /// * `dim` - The dimension to get the minimum elements and indices along.
1066    ///
1067    /// # Returns
1068    ///
1069    /// The minimum elements and corresponding indices along the dimension.
1070    fn int_min_dim_with_indices(tensor: IntTensor<B>, dim: usize) -> (IntTensor<B>, IntTensor<B>) {
1071        let indices = B::int_argmin(tensor.clone(), dim);
1072        let values = B::int_gather(dim, tensor, indices.clone());
1073
1074        (values, indices)
1075    }
1076
    /// Returns a new tensor with absolute values.
    ///
    /// Used by the default `int_max_abs_dim` implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take absolute value of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with absolute values.
    // NOTE(review): behavior for the most negative value of a signed dtype (whose absolute
    // value is not representable) is backend-defined — confirm per backend.
    fn int_abs(tensor: IntTensor<B>) -> IntTensor<B>;
1087
1088    /// Transposes an int tensor.
1089    ///
1090    /// # Arguments
1091    ///
1092    /// * `tensor` - The tensor to transpose.
1093    ///
1094    /// # Returns
1095    ///
1096    /// The transposed tensor.
1097    fn int_transpose(tensor: IntTensor<B>) -> IntTensor<B> {
1098        let ndims = tensor.shape().num_dims();
1099        Self::int_swap_dims(tensor, ndims - 2, ndims - 1)
1100    }
1101
    /// Swaps two dimensions of an int tensor.
    ///
    /// Used by the default `int_transpose` implementation to swap the last two dimensions.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to swap the dimensions of.
    /// * `dim1` - The first dimension to swap.
    /// * `dim2` - The second dimension to swap.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions swapped.
    fn int_swap_dims(tensor: IntTensor<B>, dim1: usize, dim2: usize) -> IntTensor<B>;
1114
    /// Permutes the dimensions of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to permute the dimensions of.
    /// * `axes` - The new order of the dimensions.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions permuted.
    fn int_permute(tensor: IntTensor<B>, axes: &[usize]) -> IntTensor<B>;
1125
    /// Reverses the order of elements in a tensor along the given axes.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to reverse.
    /// * `axes` - The axes to reverse.
    ///
    /// # Returns
    ///
    /// The tensor with the elements reversed.
    fn int_flip(tensor: IntTensor<B>, axes: &[usize]) -> IntTensor<B>;
1135
    /// Creates a new int tensor with random values.
    ///
    /// # Arguments
    ///
    /// * `shape` - The shape of the tensor.
    /// * `distribution` - The distribution to sample from.
    /// * `device` - The device to create the tensor on.
    /// * `dtype` - The target data type.
    ///
    /// # Returns
    ///
    /// The tensor with the given shape and random values.
    fn int_random(
        shape: Shape,
        distribution: Distribution,
        device: &Device<B>,
        dtype: IntDType,
    ) -> IntTensor<B>;
1153
1154    /// Creates a new tensor with values from the given range with the given step size.
1155    ///
1156    /// # Arguments
1157    ///
1158    /// * `range` - The range of values.
1159    /// * `step` - The step size.
1160    /// * `device` - The device to create the tensor on.
1161    /// * `dtype` - The target data type.
1162    ///
1163    /// # Returns
1164    ///
1165    /// The tensor with the given values.
1166    fn int_arange_step(
1167        range: Range<i64>,
1168        step: usize,
1169        device: &Device<B>,
1170        dtype: IntDType,
1171    ) -> IntTensor<B> {
1172        let value = range
1173            .step_by(step)
1174            .map(|i| i.elem())
1175            .collect::<Vec<IntElem<B>>>();
1176        let shape = Shape::new([value.len()]);
1177        let data = TensorData::new(value, shape).convert_dtype(dtype.into());
1178        B::int_from_data(data, device)
1179    }
1180
1181    /// Creates a new tensor with values from the given range.
1182    ///
1183    /// # Arguments
1184    ///
1185    /// * `range` - The range of values.
1186    /// * `device` - The device to create the tensor on.
1187    ///
1188    /// # Returns
1189    ///
1190    /// The tensor with the given values.
1191    ///
1192    /// # Remarks
1193    ///
1194    /// Uses `arange_step` with a step size of 1 under the hood.
1195    fn int_arange(range: Range<i64>, device: &Device<B>, dtype: IntDType) -> IntTensor<B> {
1196        Self::int_arange_step(range, 1, device, dtype)
1197    }
1198
1199    /// Tests if any element in the int `tensor` evaluates to True.
1200    ///
1201    /// # Arguments
1202    ///
1203    /// * `tensor` - The tensor to test.
1204    ///
1205    /// # Returns
1206    ///
1207    /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
1208    fn int_any(tensor: IntTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
1209        let int_dtype = tensor.dtype();
1210        let bool_tensor = B::int_equal_elem(tensor, 0.into(), out_dtype);
1211        let bool_tensor = B::bool_not(bool_tensor);
1212        let sum = B::int_sum(B::bool_into_int(bool_tensor, int_dtype.into()));
1213        B::int_greater_elem(sum, 0.into(), out_dtype)
1214    }
1215
1216    /// Tests if any element in the int `tensor` evaluates to True along a given dimension `dim`.
1217    ///
1218    /// # Arguments
1219    ///
1220    /// * `tensor` - The tensor to test.
1221    /// * `dim` - The axis along which to test.
1222    ///
1223    /// # Returns
1224    ///
1225    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1226    /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input
1227    /// evaluates to True, False otherwise.
1228    fn int_any_dim(tensor: IntTensor<B>, dim: usize, out_dtype: BoolDType) -> BoolTensor<B> {
1229        let int_dtype = tensor.dtype();
1230        let bool_tensor = B::int_equal_elem(tensor, 0.into(), out_dtype);
1231        let bool_tensor = B::bool_not(bool_tensor);
1232        let sum = B::int_sum_dim(B::bool_into_int(bool_tensor, int_dtype.into()), dim);
1233        B::int_greater_elem(sum, 0.into(), out_dtype)
1234    }
1235
1236    /// Tests if all elements in the int `tensor` evaluate to True.
1237    ///
1238    /// # Arguments
1239    ///
1240    /// * `tensor` - The tensor to test.
1241    /// * `out_dtype` - The output tensor dtype.
1242    ///
1243    /// # Returns
1244    ///
1245    /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1246    /// evaluate to True, False otherwise.
1247    fn int_all(tensor: IntTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
1248        let int_dtype = tensor.dtype();
1249        let num_elems = tensor.shape().num_elements() as i64;
1250        let bool_tensor = B::int_equal_elem(tensor, 0.into(), out_dtype);
1251        let bool_tensor = B::bool_not(bool_tensor);
1252        let sum = B::int_sum(B::bool_into_int(bool_tensor, int_dtype.into()));
1253        B::int_equal_elem(sum, num_elems.into(), out_dtype)
1254    }
1255
1256    /// Tests if all elements in the int `tensor` evaluate to True along a given dimension `dim`.
1257    ///
1258    /// # Arguments
1259    ///
1260    /// * `tensor` - The tensor to test.
1261    /// * `dim` - The axis along which to test.
1262    /// * `out_dtype` - The output tensor dtype.
1263    ///
1264    /// # Returns
1265    ///
1266    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1267    /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1268    /// evaluates to True, False otherwise.
1269    fn int_all_dim(tensor: IntTensor<B>, dim: usize, out_dtype: BoolDType) -> BoolTensor<B> {
1270        let int_dtype = tensor.dtype();
1271        let num_elems = tensor.shape()[dim] as i64;
1272        let bool_tensor = B::int_equal_elem(tensor, 0.into(), out_dtype);
1273        let bool_tensor = B::bool_not(bool_tensor);
1274        let sum = B::int_sum_dim(B::bool_into_int(bool_tensor, int_dtype.into()), dim);
1275        B::int_equal_elem(sum, num_elems.into(), out_dtype)
1276    }
1277
1278    /// Returns the signs of the int `tensor`.
1279    ///
1280    /// # Arguments
1281    ///
1282    /// * `tensor` - The tensor to extract the signs from.
1283    ///
1284    /// # Returns
1285    ///
1286    /// A tensor with the same shape as `tensor` containing the signs of the elements of `tensor`.
1287    fn int_sign(tensor: IntTensor<B>) -> IntTensor<B> {
1288        let dtype = tensor.dtype();
1289        let device = B::int_device(&tensor);
1290        let bool_dtype = get_device_settings::<B>(&B::int_device(&tensor)).bool_dtype;
1291        let zeros = B::int_zeros(tensor.shape(), &device, dtype.into());
1292        let less_than_zero = B::int_lower_elem(tensor.clone(), 0.into(), bool_dtype);
1293        let greater_than_zero = B::int_greater_elem(tensor, 0.into(), bool_dtype);
1294
1295        let mut result = B::int_mask_fill(zeros, less_than_zero, (-1).into());
1296        result = B::int_mask_fill(result, greater_than_zero, 1.into());
1297        result
1298    }
1299
    /// Broadcasts the int `tensor` to the given `shape`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to broadcast.
    /// * `shape` - The target shape.
    ///
    /// # Returns
    ///
    /// The tensor broadcast to `shape`.
    fn int_expand(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B>;
1302
1303    /// Sort the elements of the input `tensor` by value along a given dimension.
1304    ///
1305    /// This sort is unstable (i.e., may reorder equal elements).
1306    ///
1307    /// # Arguments
1308    ///
1309    /// * `tensor` - The input tensor.
1310    /// * `dim` - The axis along which to sort.
1311    /// * `descending` - The sorting order.
1312    ///
1313    /// # Returns
1314    ///
1315    /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
1316    fn int_sort(tensor: IntTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
1317        sort::<B, Int>(tensor, dim, descending)
1318    }
1319
1320    /// Sort the elements of the input `tensor` by value along a given dimension.
1321    ///
1322    /// This sort is unstable (i.e., may reorder equal elements).
1323    ///
1324    /// # Arguments
1325    ///
1326    /// * `tensor` - The input tensor.
1327    /// * `dim` - The axis along which to sort.
1328    ///
1329    /// # Returns
1330    ///
1331    /// A tensor with the same shape as the input tensor and corresponding indices, where
1332    /// the elements are sorted by value and the indices map back to the original input tensor.
1333    fn int_sort_with_indices(
1334        tensor: IntTensor<B>,
1335        dim: usize,
1336        descending: bool,
1337    ) -> (IntTensor<B>, IntTensor<B>) {
1338        let dtype = tensor.dtype();
1339        sort_with_indices::<B, Int>(tensor, dim, descending, dtype.into())
1340    }
1341
1342    /// Returns the indices that sort the elements of the input `tensor` by value
1343    /// along a given dimension.
1344    ///
1345    /// This sort is unstable (i.e., may reorder equal elements).
1346    ///
1347    /// # Arguments
1348    ///
1349    /// * `tensor` - The input tensor.
1350    /// * `dim` - The axis along which to sort.
1351    /// * `descending` - The sorting order.
1352    ///
1353    /// # Returns
1354    ///
1355    /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.
1356    fn int_argsort(tensor: IntTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
1357        let dtype = tensor.dtype();
1358        argsort::<B, Int>(tensor, dim, descending, dtype.into())
1359    }
1360
    /// Bitwise AND operation for Int Tensors.
    ///
    /// Computes the bitwise AND of `lhs` and `rhs`.
    fn bitwise_and(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

    /// Bitwise AND operation for Int Tensors with a scalar.
    ///
    /// Computes the bitwise AND of each element of `lhs` with the scalar `rhs`.
    fn bitwise_and_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;

    /// Bitwise OR operation for Int Tensors.
    ///
    /// Computes the bitwise OR of `lhs` and `rhs`.
    fn bitwise_or(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

    /// Bitwise OR operation for Int Tensors with a scalar.
    ///
    /// Computes the bitwise OR of each element of `lhs` with the scalar `rhs`.
    fn bitwise_or_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;

    /// Bitwise XOR operation for Int Tensors.
    ///
    /// Computes the bitwise XOR of `lhs` and `rhs`.
    fn bitwise_xor(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

    /// Bitwise XOR operation for Int Tensors with a scalar.
    ///
    /// Computes the bitwise XOR of each element of `lhs` with the scalar `rhs`.
    fn bitwise_xor_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;

    /// Bitwise NOT operation for Int Tensors.
    ///
    /// Inverts every bit of each element of `tensor`.
    fn bitwise_not(tensor: IntTensor<B>) -> IntTensor<B>;

    /// Bitwise left shift operation for Int Tensors.
    ///
    /// Shifts the elements of `lhs` left by the amounts given in `rhs`.
    // NOTE(review): behavior for shift amounts >= the dtype's bit width is backend-defined.
    fn bitwise_left_shift(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

    /// Bitwise left shift operation for Int Tensors with a scalar.
    ///
    /// Shifts each element of `lhs` left by the scalar amount `rhs`.
    fn bitwise_left_shift_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;

    /// Bitwise right shift operation for Int Tensors.
    ///
    /// Shifts the elements of `lhs` right by the amounts given in `rhs`.
    // NOTE(review): whether the shift is arithmetic or logical for signed dtypes is
    // backend-defined — confirm.
    fn bitwise_right_shift(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

    /// Bitwise right shift operation for Int Tensors with a scalar.
    ///
    /// Shifts each element of `lhs` right by the scalar amount `rhs`.
    fn bitwise_right_shift_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
1393
    /// Converts a tensor to another integer data type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to convert.
    /// * `dtype` - The target data type.
    ///
    /// # Returns
    ///
    /// A tensor with the same values as `tensor` but in the target integer data type.
    // NOTE(review): handling of values that do not fit in `dtype` (wrap vs. saturate) is
    // backend-defined — confirm.
    fn int_cast(tensor: IntTensor<B>, dtype: IntDType) -> IntTensor<B>;
1405
    /// Unfolds windows along a dimension.
    ///
    /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`,
    /// where windows are advanced by `step` at each index.
    ///
    /// The number of windows is `(shape[dim] - size) / step + 1` (integer division) when
    /// `shape[dim] >= size`, and `0` otherwise.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
    /// * `dim` - The selected dim.
    /// * `size` - The size of each unfolded window.
    /// * `step` - The step between each window.
    ///
    /// # Returns
    ///
    /// A tensor view with shape ``[pre=..., windows, size, post=...]``.
    fn int_unfold(tensor: IntTensor<B>, dim: usize, size: usize, step: usize) -> IntTensor<B>;
1424}