burn_tensor/tensor/ops/
int_tensor.rs

1use super::cat::cat_with_slice_assign;
2use super::repeat_dim::repeat_with_slice_assign;
3use super::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
4use crate::{
5    Distribution, ElementConversion, Int, IntDType, TensorData, backend::Backend, tensor::Shape,
6};
7use crate::{TensorMetadata, argsort, sort, sort_with_indices};
8use alloc::vec::Vec;
9use core::ops::Range;
10
11/// Int Tensor API for basic and numeric operations, see [tensor](crate::Tensor)
12/// for documentation on each function.
13pub trait IntTensorOps<B: Backend> {
    /// Creates a new int tensor.
    ///
    /// # Arguments
    ///
    /// * `shape` - The shape of the tensor.
    /// * `device` - The device to create the tensor on.
    /// * `dtype` - The target data type.
    ///
    /// # Returns
    ///
    /// The integer tensor with the given shape.
    fn int_empty(shape: Shape, device: &Device<B>, dtype: IntDType) -> IntTensor<B>;

    /// Converts the tensor to a data structure.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// A `Send` future that resolves to the data structure holding the tensor's data.
    fn int_into_data(tensor: IntTensor<B>) -> impl Future<Output = TensorData> + Send;

    /// Creates a tensor from the data structure.
    ///
    /// # Arguments
    ///
    /// * `data` - The data structure.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// The tensor with the data.
    fn int_from_data(data: TensorData, device: &Device<B>) -> IntTensor<B>;

    /// Gets the device of the tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The device of the tensor.
    fn int_device(tensor: &IntTensor<B>) -> Device<B>;

    /// Moves the tensor to the given device.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to move.
    /// * `device` - The target device.
    ///
    /// # Returns
    ///
    /// The tensor placed on `device`.
    fn int_to_device(tensor: IntTensor<B>, device: &Device<B>) -> IntTensor<B>;

    /// Reshapes the tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `shape` - The new shape.
    ///
    /// # Returns
    ///
    /// The tensor with the new shape.
    fn int_reshape(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B>;
75
    /// Slices the tensor, keeping the elements selected by `slices`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `slices` - The slices specifying ranges and steps for each dimension.
    ///
    /// # Returns
    ///
    /// The elements selected by the given slices.
    fn int_slice(tensor: IntTensor<B>, slices: &[crate::Slice]) -> IntTensor<B>;

    /// Sets the values in the tensor for the given slices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `slices` - The slices specifying ranges and steps for each dimension of the region
    ///   to assign.
    /// * `value` - The tensor whose values are written into that region.
    ///
    /// # Returns
    ///
    /// The tensor with the values set for the given slices.
    fn int_slice_assign(
        tensor: IntTensor<B>,
        slices: &[crate::Slice],
        value: IntTensor<B>,
    ) -> IntTensor<B>;

    /// Converts int tensor to float tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The int tensor.
    ///
    /// # Returns
    ///
    /// The float tensor with the same data as the int tensor.
    fn int_into_float(tensor: IntTensor<B>) -> FloatTensor<B>;
114
    /// Fills the tensor with values from the source tensor if the mask is true at the given
    /// indices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `mask` - The mask.
    /// * `source` - The source tensor.
    ///
    /// # Returns
    ///
    /// The tensor with the values filled.
    fn int_mask_where(
        tensor: IntTensor<B>,
        mask: BoolTensor<B>,
        source: IntTensor<B>,
    ) -> IntTensor<B>;

    /// Fills the tensor with the given value if the mask is true at the given indices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `mask` - The mask.
    /// * `value` - The value.
    ///
    /// # Returns
    ///
    /// The tensor with the values filled.
    fn int_mask_fill(tensor: IntTensor<B>, mask: BoolTensor<B>, value: IntElem<B>) -> IntTensor<B>;

    /// Gather elements from the tensor at the given indices.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension to gather from.
    /// * `tensor` - The tensor.
    /// * `indices` - The indices.
    ///
    /// # Returns
    ///
    /// The gathered elements.
    fn int_gather(dim: usize, tensor: IntTensor<B>, indices: IntTensor<B>) -> IntTensor<B>;

    /// Scatter a given value to the tensor at the given indices.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension to scatter to.
    /// * `tensor` - The tensor.
    /// * `indices` - The indices.
    /// * `value` - The value.
    ///
    /// # Returns
    ///
    /// The tensor with the values scattered.
    fn int_scatter(
        dim: usize,
        tensor: IntTensor<B>,
        indices: IntTensor<B>,
        value: IntTensor<B>,
    ) -> IntTensor<B>;

    /// Select tensor elements along the given dimension corresponding to the given indices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `dim` - The dimension to select from.
    /// * `indices` - The indices.
    ///
    /// # Returns
    ///
    /// The tensor with the selected elements.
    fn int_select(tensor: IntTensor<B>, dim: usize, indices: IntTensor<B>) -> IntTensor<B>;

    /// Assign the selected elements along the given dimension corresponding to the given indices
    /// to the given value.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `dim` - The dimension to select from.
    /// * `indices` - The indices.
    /// * `value` - The value.
    ///
    /// # Returns
    ///
    /// The tensor with the selected elements assigned to the given value.
    fn int_select_assign(
        tensor: IntTensor<B>,
        dim: usize,
        indices: IntTensor<B>,
        value: IntTensor<B>,
    ) -> IntTensor<B>;
206
    /// Repeats the tensor along the given dimension the given number of times.
    ///
    /// The default implementation delegates to `repeat_with_slice_assign`; backends with a
    /// native repeat kernel can override this method.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `dim` - The dimension to repeat.
    /// * `times` - The number of times to repeat.
    ///
    /// # Returns
    ///
    /// The tensor with the given dimension repeated the given number of times.
    fn int_repeat_dim(tensor: IntTensor<B>, dim: usize, times: usize) -> IntTensor<B> {
        repeat_with_slice_assign::<B, Int>(tensor, dim, times)
    }
221
    /// Concatenates the given tensors along the given dimension.
    ///
    /// The default implementation delegates to `cat_with_slice_assign`; backends with a
    /// native concatenation kernel can override this method.
    ///
    /// # Arguments
    ///
    /// * `tensors` - The tensors.
    /// * `dim` - The dimension to concatenate along.
    ///
    /// # Returns
    ///
    /// The concatenated tensor.
    fn int_cat(tensors: Vec<IntTensor<B>>, dim: usize) -> IntTensor<B> {
        cat_with_slice_assign::<B, Int>(tensors, dim)
    }
235
    /// Element-wise equality comparison.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// The boolean tensor with the result of the comparison (true where `lhs == rhs`).
    fn int_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;
247
248    /// Element-wise non-equality comparison.
249    ///
250    /// # Arguments
251    ///
252    /// * `lhs` - The left-hand side tensor.
253    /// * `rhs` - The right-hand side tensor.
254    ///
255    /// # Returns
256    ///
257    /// The boolean tensor with the result of the comparison.
258    fn int_not_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {
259        let equal_tensor = B::int_equal(lhs, rhs);
260        B::bool_not(equal_tensor)
261    }
262
    /// Element-wise equality comparison with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar.
    ///
    /// # Returns
    ///
    /// The boolean tensor with the result of the comparison (true where an element equals `rhs`).
    fn int_equal_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;
274
275    /// Element-wise non-equality comparison with a scalar.
276    ///
277    /// # Arguments
278    ///
279    /// * `lhs` - The left-hand side tensor.
280    /// * `rhs` - The right-hand side scalar.
281    ///
282    /// # Returns
283    ///
284    /// The boolean tensor with the result of the comparison.
285    fn int_not_equal_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B> {
286        let equal_tensor = B::int_equal_elem(lhs, rhs);
287        B::bool_not(equal_tensor)
288    }
289
    /// Element-wise greater than comparison.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// The boolean tensor with the result of the comparison (true where `lhs > rhs`).
    fn int_greater(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;

    /// Element-wise greater than comparison with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar.
    ///
    /// # Returns
    ///
    /// The boolean tensor with the result of the comparison (true where an element is `> rhs`).
    fn int_greater_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;

    /// Element-wise greater than or equal comparison.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// The boolean tensor with the result of the comparison (true where `lhs >= rhs`).
    fn int_greater_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;

    /// Element-wise greater than or equal comparison with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar.
    ///
    /// # Returns
    ///
    /// The boolean tensor with the result of the comparison (true where an element is `>= rhs`).
    fn int_greater_equal_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;

    /// Element-wise less than comparison ("lower" in this API's naming).
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// The boolean tensor with the result of the comparison (true where `lhs < rhs`).
    fn int_lower(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;

    /// Element-wise less than comparison with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar.
    ///
    /// # Returns
    ///
    /// The boolean tensor with the result of the comparison (true where an element is `< rhs`).
    fn int_lower_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;

    /// Element-wise less than or equal comparison.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// The boolean tensor with the result of the comparison (true where `lhs <= rhs`).
    fn int_lower_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;

    /// Element-wise less than or equal comparison with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar.
    ///
    /// # Returns
    ///
    /// The boolean tensor with the result of the comparison (true where an element is `<= rhs`).
    fn int_lower_equal_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;
385
386    // ====  NUMERIC ==== //
387
    /// Element-wise addition.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// The result of the addition.
    fn int_add(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

    /// Element-wise addition with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar, added to every element of `lhs`.
    ///
    /// # Returns
    ///
    /// The result of the addition.
    fn int_add_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
411
412    /// Element-wise power with a IntTensor.
413    ///
414    /// # Arguments
415    ///
416    /// * `lhs` - The left-hand side IntTensor.
417    /// * `rhs` - The right-hand side IntTensor.
418    ///
419    /// # Returns
420    ///
421    /// The elements of `lhs` raised to the power of the elements of `rhs`.
422    fn int_powi(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
423        B::float_into_int(B::float_powi(B::int_into_float(lhs), rhs))
424    }
425
426    /// Element-wise power with a floatTensor.
427    ///
428    /// # Arguments
429    ///
430    /// * `lhs` - The left-hand side tensor.
431    /// * `rhs` - The right-hand side floatTensor.
432    ///
433    /// # Returns
434    ///
435    /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
436    fn int_powf(lhs: IntTensor<B>, rhs: FloatTensor<B>) -> IntTensor<B> {
437        B::float_into_int(B::float_powf(B::int_into_float(lhs), rhs))
438    }
439
440    /// Element-wise power with a scalar.
441    ///
442    /// # Backend Implementors Note
443    ///
444    /// A number of common exponent cases can be implemented with operations
445    /// which are much cheaper than generic exponentiation.
446    ///
447    /// This (`Backend` impl overridable) operation handles generic optimizations
448    /// for several common integer exponent cases; and then dispatches to
449    /// the (`Backend` impl overridable) [`Self::int_powi_scalar_impl`]
450    /// operation to handle the generic case.
451    ///
452    /// # Arguments
453    ///
454    /// * `lhs` - The left-hand side tensor.
455    /// * `rhs` - The right-hand side scalar.
456    ///
457    /// # Returns
458    ///
459    /// The elements of `lhs` raised to the value of `rhs`.
460    fn int_powi_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B> {
461        let exp = rhs.elem::<i32>();
462        match exp {
463            0 => Self::int_ones(lhs.shape(), &B::int_device(&lhs), lhs.dtype().into()),
464            1 => lhs,
465            2 => Self::int_mul(lhs.clone(), lhs),
466            _ => Self::int_powi_scalar_impl(lhs, rhs),
467        }
468    }
469
470    /// Element-wise power with a scalar.
471    ///
472    /// # Backend Implementors Note
473    ///
474    /// This is the generic implementation of integer exponentiation
475    /// called by [`Self::int_powi_scalar`] in the fallback case.
476    ///
477    /// By default, this performs a relatively expensive conversion to float,
478    /// exponentiation in float, and conversion back to int.
479    /// This reduces the minimal operation set for `Backend`s,
480    /// at the cost of performance.
481    ///
482    /// This is a good target for specialized optimizations in `Backend` implementations.
483    ///
484    /// As a general rule, this should not be called directly.
485    ///
486    /// # Arguments
487    ///
488    /// * `lhs` - The left-hand side tensor.
489    /// * `rhs` - The right-hand side scalar.
490    ///
491    /// # Returns
492    ///
493    /// The elements of `lhs` raised to the value of `rhs`.
494    fn int_powi_scalar_impl(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B> {
495        B::float_into_int(B::float_powi_scalar_impl(B::int_into_float(lhs), rhs))
496    }
497
498    /// Element-wise power with a floatTensor.
499    ///
500    /// Handles a number of special cases, then calls [`Self::int_powf_scalar_impl`].
501    ///
502    /// # Arguments
503    ///
504    /// * `lhs` - The left-hand side tensor.
505    /// * `rhs` - The right-hand side scalar.
506    ///
507    /// # Returns
508    ///
509    /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
510    fn int_powf_scalar(lhs: IntTensor<B>, rhs: f32) -> IntTensor<B> {
511        if num_traits::Float::floor(rhs) == rhs {
512            let exp = B::IntElem::from_elem(rhs as i32);
513            Self::int_powi_scalar(lhs, exp)
514        } else {
515            Self::int_powf_scalar_impl(lhs, rhs)
516        }
517    }
518
519    /// Element-wise power with a floatTensor.
520    ///
521    /// Fallback handler for [`Self::int_powf_scalar`].
522    ///
523    /// # Arguments
524    ///
525    /// * `lhs` - The left-hand side tensor.
526    /// * `rhs` - The right-hand side scalar.
527    ///
528    /// # Returns
529    ///
530    /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
531    fn int_powf_scalar_impl(lhs: IntTensor<B>, rhs: f32) -> IntTensor<B> {
532        B::float_into_int(B::float_powf_scalar_impl(B::int_into_float(lhs), rhs))
533    }
534
535    /// Clamps a tensor under a minimum value.
536    ///
537    /// # Arguments
538    ///
539    /// * `tensor` - The tensor to clamp.
540    /// * `min` - The minimum value.
541    ///
542    /// # Returns
543    ///
544    /// The clamped tensor.
545    fn int_clamp_min(tensor: IntTensor<B>, min: IntElem<B>) -> IntTensor<B> {
546        let mask = Self::int_lower_elem(tensor.clone(), min);
547        Self::int_mask_fill(tensor, mask, min)
548    }
549
550    /// Clamps a tensor over a maximum value.
551    ///
552    /// # Arguments
553    ///
554    /// * `tensor` - The tensor to clamp.
555    /// * `max` - The maximum value.
556    ///
557    /// # Returns
558    ///
559    /// The clamped tensor.
560    fn int_clamp_max(tensor: IntTensor<B>, max: IntElem<B>) -> IntTensor<B> {
561        let mask = Self::int_greater_elem(tensor.clone(), max);
562        Self::int_mask_fill(tensor, mask, max)
563    }
564
565    /// Clamps a tensor between a minimum and maximum value.
566    ///
567    /// # Arguments
568    ///
569    /// * `tensor` - The tensor to clamp.
570    /// * `min` - The minimum value.
571    /// * `max` - The maximum value.
572    ///
573    /// # Returns
574    ///
575    /// The clamped tensor.
576    fn int_clamp(tensor: IntTensor<B>, min: IntElem<B>, max: IntElem<B>) -> IntTensor<B> {
577        Self::int_clamp_min(Self::int_clamp_max(tensor, max), min)
578    }
579
    /// Element-wise subtraction.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// The result of the subtraction.
    fn int_sub(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

    /// Element-wise subtraction with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar.
    ///
    /// # Returns
    ///
    /// The result of the subtraction.
    fn int_sub_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;

    /// Element-wise multiplication.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// The result of the multiplication.
    fn int_mul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

    /// Element-wise multiplication with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar.
    ///
    /// # Returns
    ///
    /// The result of the multiplication.
    fn int_mul_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;

    /// Element-wise division.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// The result of the division.
    fn int_div(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

    /// Element-wise division with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar.
    ///
    /// # Returns
    ///
    /// The result of the division.
    fn int_div_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;

    /// Element-wise modulus.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor (the divisors).
    ///
    /// # Returns
    ///
    /// The element-wise remainder of `lhs` divided by `rhs`.
    fn int_remainder(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

    /// Element-wise modulus with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar.
    ///
    /// # Returns
    ///
    /// The result of applying the modulus of the scalar to the tensor.
    fn int_remainder_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;

    /// Multiplies two tensors together using matrix multiplication.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// The result of multiplying the two tensors together using matrix multiplication.
    fn int_matmul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
685
686    /// Element-wise negation.
687    ///
688    /// # Arguments
689    ///
690    /// * `tensor` - The tensor to negate.
691    ///
692    /// # Returns
693    ///
694    /// The negated tensor.
695    fn int_neg(tensor: IntTensor<B>) -> IntTensor<B> {
696        Self::int_mul_scalar(tensor, (-1.0).elem::<IntElem<B>>())
697    }
698
699    /// Creates a tensor of zeros.
700    ///
701    /// # Arguments
702    ///
703    /// * `shape` - The shape of the tensor.
704    /// * `device` - The device to create the tensor on.
705    /// * `dtype` - The target data type.
706    ///
707    /// # Returns
708    ///
709    /// The tensor of zeros.
710    fn int_zeros(shape: Shape, device: &Device<B>, dtype: IntDType) -> IntTensor<B> {
711        Self::int_from_data(TensorData::full_dtype(shape, 0, dtype.into()), device)
712    }
713
714    /// Creates a tensor of ones.
715    ///
716    /// # Arguments
717    ///
718    /// * `shape` - The shape of the tensor.
719    /// * `device` - The device to create the tensor on.
720    /// * `dtype` - The target data type.
721    ///
722    /// # Returns
723    ///
724    /// The tensor of ones.
725    fn int_ones(shape: Shape, device: &Device<B>, dtype: IntDType) -> IntTensor<B> {
726        Self::int_from_data(TensorData::full_dtype(shape, 1, dtype.into()), device)
727    }
728
729    /// Creates a tensor filled with given value.
730    ///
731    /// # Arguments
732    ///
733    /// * `shape` - The shape of the tensor.
734    /// * `fill_value` - The value with which to fill the tensor.
735    /// * `device` - The device to create the tensor on.
736    /// * `dtype` - The target data type.
737    ///
738    /// # Returns
739    ///
740    /// The tensor filled with given value
741    fn int_full(
742        shape: Shape,
743        fill_value: IntElem<B>,
744        device: &Device<B>,
745        dtype: IntDType,
746    ) -> IntTensor<B> {
747        Self::int_from_data(
748            TensorData::full_dtype(shape, fill_value, dtype.into()),
749            device,
750        )
751    }
752
    /// Sums all elements in the tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to sum.
    ///
    /// # Returns
    ///
    /// A tensor holding the sum of all elements in the input.
    fn int_sum(tensor: IntTensor<B>) -> IntTensor<B>;

    /// Sums all elements in the tensor along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to sum.
    /// * `dim` - The dimension to sum along.
    ///
    /// # Returns
    ///
    /// The sum of all elements in the tensor along the dimension.
    fn int_sum_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;

    /// Computes the product of all elements in the tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the product of.
    ///
    /// # Returns
    ///
    /// A tensor holding the product of all elements in the input.
    fn int_prod(tensor: IntTensor<B>) -> IntTensor<B>;

    /// Computes the product of all elements in the tensor along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the product of.
    /// * `dim` - The dimension to compute the product along.
    ///
    /// # Returns
    ///
    /// The product of all elements in the tensor along the dimension.
    fn int_prod_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
798
799    /// Computes the mean of all elements in the tensor.
800    ///
801    /// # Arguments
802    ///
803    /// * `tensor` - The tensor to compute the mean of.
804    ///
805    /// # Returns
806    ///
807    /// The mean of all elements in the tensor.
808    fn int_mean(tensor: IntTensor<B>) -> IntTensor<B> {
809        let num_elems = tensor.shape().num_elements();
810        B::int_div_scalar(B::int_sum(tensor), (num_elems as i64).elem())
811    }
812
    /// Computes the mean of all elements in the tensor along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the mean of.
    /// * `dim` - The dimension to compute the mean along.
    ///
    /// # Returns
    ///
    /// The mean of all elements in the tensor along the dimension.
    fn int_mean_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;

    /// Computes the cumulative sum of elements along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the cumulative sum of.
    /// * `dim` - The dimension along which to compute the cumulative sum.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape where each element is the cumulative sum
    /// of all elements up to and including that position along the dimension.
    fn int_cumsum(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;

    /// Computes the cumulative product of elements along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the cumulative product of.
    /// * `dim` - The dimension along which to compute the cumulative product.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape where each element is the cumulative product
    /// of all elements up to and including that position along the dimension.
    fn int_cumprod(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;

    /// Computes the cumulative minimum of elements along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the cumulative minimum of.
    /// * `dim` - The dimension along which to compute the cumulative minimum.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape where each element is the minimum
    /// of all elements up to and including that position along the dimension.
    fn int_cummin(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;

    /// Computes the cumulative maximum of elements along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the cumulative maximum of.
    /// * `dim` - The dimension along which to compute the cumulative maximum.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape where each element is the maximum
    /// of all elements up to and including that position along the dimension.
    fn int_cummax(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;

    /// Gets the indices of the maximum elements along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the maximum indices of.
    /// * `dim` - The dimension to get the maximum indices along.
    ///
    /// # Returns
    ///
    /// The indices of the maximum elements along the dimension.
    fn int_argmax(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;

    /// Gets the indices of the minimum elements along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the minimum indices of.
    /// * `dim` - The dimension to get the minimum indices along.
    ///
    /// # Returns
    ///
    /// The indices of the minimum elements along the dimension.
    fn int_argmin(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
899
900    /// Gets the maximum element in the tensor.
901    ///
902    /// # Arguments
903    ///
904    /// * `tensor` - The tensor to get the maximum element of.
905    ///
906    /// # Returns
907    ///
908    /// The maximum element in the tensor.
909    fn int_max(tensor: IntTensor<B>) -> IntTensor<B> {
910        let shape = tensor.shape();
911        let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()]));
912
913        B::int_max_dim(tensor, 0)
914    }
915
916    /// Gets the maximum element in the tensor along a dimension.
917    ///
918    /// # Arguments
919    ///
920    /// * `tensor` - The tensor to get the maximum element of.
921    /// * `dim` - The dimension to get the maximum element along.
922    ///
923    /// # Returns
924    ///
925    /// The maximum element in the tensor along the dimension.
926    fn int_max_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
927        let index = B::int_argmax(tensor.clone(), dim);
928        B::int_gather(dim, tensor, index)
929    }
930
931    /// Gets the maximum elements and corresponding indices along a dimension.
932    ///
933    /// # Arguments
934    ///
935    /// * `tensor` - The tensor to get the maximum elements and indices of.
936    /// * `dim` - The dimension to get the maximum elements and indices along.
937    ///
938    /// # Returns
939    ///
940    /// The maximum elements and corresponding indices along the dimension.
941    fn int_max_dim_with_indices(tensor: IntTensor<B>, dim: usize) -> (IntTensor<B>, IntTensor<B>) {
942        let index = B::int_argmax(tensor.clone(), dim);
943        let values = B::int_gather(dim, tensor, index.clone());
944
945        (values, index)
946    }
947
948    /// Gets the maximum absolute element in the tensor.
949    ///
950    /// # Arguments
951    ///
952    /// * `tensor` - The tensor to get the maximum element of.
953    ///
954    /// # Returns
955    ///
956    /// The maximum element in the tensor.
957    fn int_max_abs(tensor: IntTensor<B>) -> IntTensor<B> {
958        let shape = tensor.shape();
959        let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()]));
960
961        B::int_max_abs_dim(tensor, 0)
962    }
963
964    /// Gets the maximum absolute element in the tensor along a dimension.
965    ///
966    /// # Arguments
967    ///
968    /// * `tensor` - The tensor to get the maximum element of.
969    /// * `dim` - The dimension to get the maximum element along.
970    ///
971    /// # Returns
972    ///
973    /// The maximum element in the tensor along the dimension.
974    fn int_max_abs_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
975        B::int_max_dim(B::int_abs(tensor), dim)
976    }
977
978    /// Gets the minimum element in the tensor.
979    ///
980    /// # Arguments
981    ///
982    /// * `tensor` - The tensor to get the minimum element of.
983    ///
984    /// # Returns
985    ///
986    /// The minimum element in the tensor.
987    fn int_min(tensor: IntTensor<B>) -> IntTensor<B> {
988        let shape = tensor.shape();
989        let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()]));
990
991        B::int_min_dim(tensor, 0)
992    }
993
994    /// Gets the minimum elements in the tensor along a dimension.
995    ///
996    /// # Arguments
997    ///
998    /// * `tensor` - The tensor to get the minimum element of.
999    /// * `dim` - The dimension to get the minimum element along.
1000    ///
1001    /// # Returns
1002    ///
1003    /// The minimum element in the tensor along the dimension.
1004    fn int_min_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
1005        let index = B::int_argmin(tensor.clone(), dim);
1006        B::int_gather(dim, tensor, index)
1007    }
1008
1009    /// Gets the minimum elements and corresponding indices along a dimension.
1010    ///
1011    /// # Arguments
1012    ///
1013    /// * `tensor` - The tensor to get the minimum elements and indices of.
1014    /// * `dim` - The dimension to get the minimum elements and indices along.
1015    ///
1016    /// # Returns
1017    ///
1018    /// The minimum elements and corresponding indices along the dimension.
1019    fn int_min_dim_with_indices(tensor: IntTensor<B>, dim: usize) -> (IntTensor<B>, IntTensor<B>) {
1020        let indices = B::int_argmin(tensor.clone(), dim);
1021        let values = B::int_gather(dim, tensor, indices.clone());
1022
1023        (values, indices)
1024    }
1025
    /// Returns a new tensor with absolute values.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take absolute value of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with absolute values.
    ///
    /// NOTE(review): the result for the minimum representable value of the element
    /// type (e.g. `i32::MIN`) is left to the backend implementation — confirm.
    fn int_abs(tensor: IntTensor<B>) -> IntTensor<B>;
1036
1037    /// Transposes an int tensor.
1038    ///
1039    /// # Arguments
1040    ///
1041    /// * `tensor` - The tensor to transpose.
1042    ///
1043    /// # Returns
1044    ///
1045    /// The transposed tensor.
1046    fn int_transpose(tensor: IntTensor<B>) -> IntTensor<B> {
1047        let ndims = tensor.shape().num_dims();
1048        Self::int_swap_dims(tensor, ndims - 2, ndims - 1)
1049    }
1050
    /// Swaps two dimensions of an int tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to swap the dimensions of.
    /// * `dim1` - The first dimension to swap.
    /// * `dim2` - The second dimension to swap.
    ///
    /// # Returns
    ///
    /// The tensor with dimensions `dim1` and `dim2` swapped.
    fn int_swap_dims(tensor: IntTensor<B>, dim1: usize, dim2: usize) -> IntTensor<B>;
1063
    /// Permutes the dimensions of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to permute the dimensions of.
    /// * `axes` - The new order of the dimensions.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions permuted.
    fn int_permute(tensor: IntTensor<B>, axes: &[usize]) -> IntTensor<B>;
1074
    /// Reverses the order of elements in a tensor along the given axes.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to reverse.
    /// * `axes` - The axes to reverse.
    ///
    /// # Returns
    ///
    /// The tensor with the elements reversed.
    fn int_flip(tensor: IntTensor<B>, axes: &[usize]) -> IntTensor<B>;
1084
    /// Creates a new int tensor with random values.
    ///
    /// # Arguments
    ///
    /// * `shape` - The shape of the tensor.
    /// * `distribution` - The distribution to sample from.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// The tensor with the given shape and random values.
    fn int_random(shape: Shape, distribution: Distribution, device: &Device<B>) -> IntTensor<B>;
1096
1097    /// Creates a new tensor with values from the given range with the given step size.
1098    ///
1099    /// # Arguments
1100    ///
1101    /// * `range` - The range of values.
1102    /// * `step` - The step size.
1103    /// * `device` - The device to create the tensor on.
1104    ///
1105    /// # Returns
1106    ///
1107    /// The tensor with the given values.
1108    fn int_arange_step(range: Range<i64>, step: usize, device: &Device<B>) -> IntTensor<B> {
1109        let value = range
1110            .step_by(step)
1111            .map(|i| i.elem())
1112            .collect::<Vec<IntElem<B>>>();
1113        let shape = Shape::new([value.len()]);
1114        let data = TensorData::new(value, shape);
1115        B::int_from_data(data, device)
1116    }
1117
1118    /// Creates a new tensor with values from the given range.
1119    ///
1120    /// # Arguments
1121    ///
1122    /// * `range` - The range of values.
1123    /// * `device` - The device to create the tensor on.
1124    ///
1125    /// # Returns
1126    ///
1127    /// The tensor with the given values.
1128    ///
1129    /// # Remarks
1130    ///
1131    /// Uses `arange_step` with a step size of 1 under the hood.
1132    fn int_arange(range: Range<i64>, device: &Device<B>) -> IntTensor<B> {
1133        Self::int_arange_step(range, 1, device)
1134    }
1135
1136    /// Tests if any element in the int `tensor` evaluates to True.
1137    ///
1138    /// # Arguments
1139    ///
1140    /// * `tensor` - The tensor to test.
1141    ///
1142    /// # Returns
1143    ///
1144    /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
1145    fn int_any(tensor: IntTensor<B>) -> BoolTensor<B> {
1146        let bool_tensor = B::int_equal_elem(tensor, 0.elem());
1147        let bool_tensor = B::bool_not(bool_tensor);
1148        let sum = B::int_sum(B::bool_into_int(bool_tensor));
1149        B::int_greater_elem(sum, 0.elem())
1150    }
1151
1152    /// Tests if any element in the int `tensor` evaluates to True along a given dimension `dim`.
1153    ///
1154    /// # Arguments
1155    ///
1156    /// * `tensor` - The tensor to test.
1157    /// * `dim` - The axis along which to test.
1158    ///
1159    /// # Returns
1160    ///
1161    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1162    /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input
1163    /// evaluates to True, False otherwise.
1164    fn int_any_dim(tensor: IntTensor<B>, dim: usize) -> BoolTensor<B> {
1165        let bool_tensor = B::int_equal_elem(tensor, 0.elem());
1166        let bool_tensor = B::bool_not(bool_tensor);
1167        let sum = B::int_sum_dim(B::bool_into_int(bool_tensor), dim);
1168        B::int_greater_elem(sum, 0.elem())
1169    }
1170
1171    /// Tests if all elements in the int `tensor` evaluate to True.
1172    ///
1173    /// # Arguments
1174    ///
1175    /// * `tensor` - The tensor to test.
1176    ///
1177    /// # Returns
1178    ///
1179    /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1180    /// evaluate to True, False otherwise.
1181    fn int_all(tensor: IntTensor<B>) -> BoolTensor<B> {
1182        let num_elems = tensor.shape().num_elements();
1183        let bool_tensor = B::int_equal_elem(tensor, 0.elem());
1184        let bool_tensor = B::bool_not(bool_tensor);
1185        let sum = B::int_sum(B::bool_into_int(bool_tensor));
1186        B::int_equal_elem(sum, (num_elems as i32).elem())
1187    }
1188
1189    /// Tests if all elements in the int `tensor` evaluate to True along a given dimension `dim`.
1190    ///
1191    /// # Arguments
1192    ///
1193    /// * `tensor` - The tensor to test.
1194    /// * `dim` - The axis along which to test.
1195    ///
1196    /// # Returns
1197    ///
1198    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1199    /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1200    /// evaluates to True, False otherwise.
1201    fn int_all_dim(tensor: IntTensor<B>, dim: usize) -> BoolTensor<B> {
1202        let num_elems = tensor.shape().dims[dim];
1203        let bool_tensor = B::int_equal_elem(tensor, 0.elem());
1204        let bool_tensor = B::bool_not(bool_tensor);
1205        let sum = B::int_sum_dim(B::bool_into_int(bool_tensor), dim);
1206        B::int_equal_elem(sum, (num_elems as i32).elem())
1207    }
1208
1209    /// Returns the signs of the int `tensor`.
1210    ///
1211    /// # Arguments
1212    ///
1213    /// * `tensor` - The tensor to extract the signs from.
1214    ///
1215    /// # Returns
1216    ///
1217    /// A tensor with the same shape as `tensor` containing the signs of the elements of `tensor`.
1218    fn int_sign(tensor: IntTensor<B>) -> IntTensor<B> {
1219        let dtype = tensor.dtype();
1220        let zeros = B::int_zeros(tensor.shape(), &B::int_device(&tensor), dtype.into());
1221        let less_than_zero = B::int_lower_elem(tensor.clone(), 0.0f32.elem());
1222        let greater_than_zero = B::int_greater_elem(tensor, 0.0f32.elem());
1223
1224        let mut result = B::int_mask_fill(zeros, less_than_zero, (-1.0f32).elem());
1225        result = B::int_mask_fill(result, greater_than_zero, 1.0f32.elem());
1226        result
1227    }
1228
    /// Broadcasts the int `tensor` to the given `shape`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to broadcast.
    /// * `shape` - The target shape.
    ///
    /// # Returns
    ///
    /// The tensor broadcast to `shape`.
    ///
    /// NOTE(review): the exact broadcasting compatibility rules are defined by the
    /// backend implementation — confirm before relying on edge cases.
    fn int_expand(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B>;
1231
1232    /// Sort the elements of the input `tensor` by value along a given dimension.
1233    ///
1234    /// This sort is unstable (i.e., may reorder equal elements).
1235    ///
1236    /// # Arguments
1237    ///
1238    /// * `tensor` - The input tensor.
1239    /// * `dim` - The axis along which to sort.
1240    /// * `descending` - The sorting order.
1241    ///
1242    /// # Returns
1243    ///
1244    /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
1245    fn int_sort(tensor: IntTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
1246        sort::<B, Int>(tensor, dim, descending)
1247    }
1248
1249    /// Sort the elements of the input `tensor` by value along a given dimension.
1250    ///
1251    /// This sort is unstable (i.e., may reorder equal elements).
1252    ///
1253    /// # Arguments
1254    ///
1255    /// * `tensor` - The input tensor.
1256    /// * `dim` - The axis along which to sort.
1257    ///
1258    /// # Returns
1259    ///
1260    /// A tensor with the same shape as the input tensor and corresponding indices, where
1261    /// the elements are sorted by value and the indices map back to the original input tensor.
1262    fn int_sort_with_indices(
1263        tensor: IntTensor<B>,
1264        dim: usize,
1265        descending: bool,
1266    ) -> (IntTensor<B>, IntTensor<B>) {
1267        sort_with_indices::<B, Int>(tensor, dim, descending)
1268    }
1269
1270    /// Returns the indices that sort the elements of the input `tensor` by value
1271    /// along a given dimension.
1272    ///
1273    /// This sort is unstable (i.e., may reorder equal elements).
1274    ///
1275    /// # Arguments
1276    ///
1277    /// * `tensor` - The input tensor.
1278    /// * `dim` - The axis along which to sort.
1279    /// * `descending` - The sorting order.
1280    ///
1281    /// # Returns
1282    ///
1283    /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.
1284    fn int_argsort(tensor: IntTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
1285        argsort::<B, Int>(tensor, dim, descending)
1286    }
1287
    /// Bitwise AND operation for Int Tensors.
    ///
    /// Returns `lhs & rhs` computed between the two tensors.
    fn bitwise_and(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

    /// Bitwise AND operation for Int Tensors with a scalar.
    ///
    /// Returns `lhs & rhs`, combining each element of `lhs` with the scalar `rhs`.
    fn bitwise_and_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;

    /// Bitwise OR operation for Int Tensors.
    ///
    /// Returns `lhs | rhs` computed between the two tensors.
    fn bitwise_or(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

    /// Bitwise OR operation for Int Tensors with a scalar.
    ///
    /// Returns `lhs | rhs`, combining each element of `lhs` with the scalar `rhs`.
    fn bitwise_or_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;

    /// Bitwise XOR operation for Int Tensors.
    ///
    /// Returns `lhs ^ rhs` computed between the two tensors.
    fn bitwise_xor(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

    /// Bitwise XOR operation for Int Tensors with a scalar.
    ///
    /// Returns `lhs ^ rhs`, combining each element of `lhs` with the scalar `rhs`.
    fn bitwise_xor_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;

    /// Bitwise NOT operation for Int Tensors.
    ///
    /// Returns `!tensor`, inverting every bit of each element.
    fn bitwise_not(tensor: IntTensor<B>) -> IntTensor<B>;

    /// Bitwise left shift operation for Int Tensors.
    ///
    /// Returns `lhs << rhs`, with shift amounts taken from `rhs`.
    /// NOTE(review): behavior for negative or too-large shift amounts is backend-defined — confirm.
    fn bitwise_left_shift(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

    /// Bitwise left shift operation for Int Tensors with a scalar.
    ///
    /// Returns `lhs << rhs`, shifting each element of `lhs` by the scalar `rhs`.
    /// NOTE(review): behavior for negative or too-large shift amounts is backend-defined — confirm.
    fn bitwise_left_shift_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;

    /// Bitwise right shift operation for Int Tensors.
    ///
    /// Returns `lhs >> rhs`, with shift amounts taken from `rhs`.
    /// NOTE(review): whether the shift is arithmetic or logical, and behavior for
    /// out-of-range shift amounts, are backend-defined — confirm.
    fn bitwise_right_shift(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

    /// Bitwise right shift operation for Int Tensors with a scalar.
    ///
    /// Returns `lhs >> rhs`, shifting each element of `lhs` by the scalar `rhs`.
    /// NOTE(review): whether the shift is arithmetic or logical, and behavior for
    /// out-of-range shift amounts, are backend-defined — confirm.
    fn bitwise_right_shift_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
1320
    /// Converts a tensor to another integer data type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to convert.
    /// * `dtype` - The target data type.
    ///
    /// # Returns
    ///
    /// A tensor with the same values as `tensor` but in the target integer data type.
    ///
    /// NOTE(review): handling of values not representable in `dtype` (narrowing casts)
    /// is backend-defined — confirm.
    fn int_cast(tensor: IntTensor<B>, dtype: IntDType) -> IntTensor<B>;
1332
    /// Unfold windows along a dimension.
    ///
    /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`,
    /// where consecutive windows start `step` indices apart.
    ///
    /// The number of windows is `(shape[dim] - size) / step + 1` (integer division) when
    /// `shape[dim] >= size`, and `0` otherwise. For example, `shape[dim] = 5`, `size = 2`,
    /// `step = 1` yields 4 windows.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
    /// * `dim` - the selected dim.
    /// * `size` - the size of each unfolded window.
    /// * `step` - the step between each window.
    ///
    /// # Returns
    ///
    /// A tensor view with shape ``[pre=..., windows, size, post=...]``.
    fn int_unfold(tensor: IntTensor<B>, dim: usize, size: usize, step: usize) -> IntTensor<B>;
1351}