burn_backend/backend/ops/
int_tensor.rs

1use super::cat::cat_with_slice_assign;
2use super::repeat_dim::repeat_with_slice_assign;
3use super::sort::{argsort, sort, sort_with_indices};
4use crate::ExecutionError;
5use crate::tensor::{BoolTensor, Device, FloatTensor, Int, IntElem, IntTensor};
6use crate::{Backend, Distribution, TensorData, TensorMetadata, element::ElementConversion};
7use alloc::vec::Vec;
8use burn_std::{IntDType, Shape, Slice};
9use core::ops::Range;
10
11/// Int Tensor API for basic and numeric operations, see
12#[cfg_attr(doc, doc = crate::doc_tensor!())]
13#[cfg_attr(not(doc), doc = "`Tensor`")]
14/// for documentation on each function.
15pub trait IntTensorOps<B: Backend> {
16    /// Creates a new int tensor.
17    ///
18    /// # Arguments
19    ///
20    /// * `shape` - The shape of the tensor.
21    /// * `device` - The device to create the tensor on.
22    /// * `dtype` - The target data type.
23    ///
24    /// # Returns
25    ///
26    /// The integer tensor with the given shape.
27    fn int_empty(shape: Shape, device: &Device<B>, dtype: IntDType) -> IntTensor<B>;
28
29    /// Converts the tensor to a data structure.
30    ///
31    /// # Arguments
32    ///
33    /// * `tensor` - The tensor.
34    ///
35    /// # Returns
36    ///
37    /// The data structure with the tensor's data.
38    fn int_into_data(
39        tensor: IntTensor<B>,
40    ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;
41
42    /// Creates a tensor from the data structure.
43    ///
44    /// # Arguments
45    ///
46    /// * `data` - The data structure.
47    /// * `device` - The device to create the tensor on.
48    ///
49    /// # Returns
50    ///
51    /// The tensor with the data.
52    fn int_from_data(data: TensorData, device: &Device<B>) -> IntTensor<B>;
53
54    /// Gets the device of the tensor.
55    ///
56    /// # Arguments
57    ///
58    /// * `tensor` - The tensor.
59    ///
60    /// # Returns
61    ///
62    /// The device of the tensor.
63    fn int_device(tensor: &IntTensor<B>) -> Device<B>;
64
65    /// Moves the tensor to the given device.
66    fn int_to_device(tensor: IntTensor<B>, device: &Device<B>) -> IntTensor<B>;
67
68    /// Reshapes the tensor.
69    ///
70    /// # Arguments
71    ///
72    /// * `tensor` - The tensor.
73    /// * `shape` - The new shape.
74    ///
75    /// # Returns
76    ///
77    /// The tensor with the new shape.
78    fn int_reshape(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B>;
79
80    /// Gets the element at the given indices.
81    ///
82    /// # Arguments
83    ///
84    /// * `tensor` - The tensor.
85    /// * `slices` - The slices specifying ranges and steps for each dimension.
86    ///
87    /// # Returns
88    ///
89    /// The elements at the given indices.
90    ///
91    /// # Note
92    ///
93    /// Empty slices (where start >= end) are handled at the high-level tensor API and will not
94    /// be passed to this method. Backend implementations do not need to handle empty slices.
95    fn int_slice(tensor: IntTensor<B>, slices: &[Slice]) -> IntTensor<B>;
96
97    /// Sets the values in the tensor for the given ranges.
98    ///
99    /// # Arguments
100    ///
101    /// * `tensor` - The tensor.
102    /// * `ranges` - The ranges to set the values for.
103    ///
104    /// # Returns
105    ///
106    /// The tensor with the values set for the given ranges.
107    ///
108    /// # Note
109    ///
110    /// Empty slice assignments (where any slice range produces 0 elements) are handled at the
111    /// high-level tensor API and will not be passed to this method. Backend implementations do
112    /// not need to handle empty slice assignments.
113    fn int_slice_assign(
114        tensor: IntTensor<B>,
115        slices: &[Slice],
116        value: IntTensor<B>,
117    ) -> IntTensor<B>;
118
119    /// Converts int tensor to float tensor.
120    ///
121    /// # Arguments
122    ///
123    /// * `tensor` - The tensor.
124    ///
125    /// # Returns
126    ///
127    /// The int tensor with the same data as the float tensor.
128    fn int_into_float(tensor: IntTensor<B>) -> FloatTensor<B>;
129
130    /// Fills the tensor with values from the source tensor if the mask is true at the given
131    /// indices.
132    ///
133    /// # Arguments
134    ///
135    /// * `tensor` - The tensor.
136    /// * `mask` - The mask.
137    /// * `source` - The source tensor.
138    ///
139    /// # Returns
140    ///
141    /// The tensor with the values filled.
142    fn int_mask_where(
143        tensor: IntTensor<B>,
144        mask: BoolTensor<B>,
145        source: IntTensor<B>,
146    ) -> IntTensor<B>;
147
148    /// Fills the tensor with the given value if the mask is true at the given indices.
149    ///
150    /// # Arguments
151    ///
152    /// * `tensor` - The tensor.
153    /// * `mask` - The mask.
154    /// * `value` - The value.
155    ///
156    /// # Returns
157    ///
158    /// The tensor with the values filled.
159    fn int_mask_fill(tensor: IntTensor<B>, mask: BoolTensor<B>, value: IntElem<B>) -> IntTensor<B>;
160
161    /// Gather elements from the tensor at the given indices.
162    ///
163    /// # Arguments
164    ///
165    /// * `dim` - The dimension to gather from.
166    /// * `tensor` - The tensor.
167    /// * `indices` - The indices.
168    fn int_gather(dim: usize, tensor: IntTensor<B>, indices: IntTensor<B>) -> IntTensor<B>;
169
170    /// Scatter a given value to the tensor at the given indices using sum reduction.
171    ///
172    /// # Arguments
173    ///
174    /// * `dim` - The dimension to scatter to.
175    /// * `tensor` - The tensor.
176    /// * `indices` - The indices.
177    /// * `value` - The value.
178    ///
179    /// # Returns
180    ///
181    /// The tensor with the values scattered.
182    fn int_scatter_add(
183        dim: usize,
184        tensor: IntTensor<B>,
185        indices: IntTensor<B>,
186        value: IntTensor<B>,
187    ) -> IntTensor<B>;
188
189    /// Select tensor elements along the given dimension corresponding to the given indices.
190    ///
191    /// # Arguments
192    ///
193    /// * `tensor` - The tensor.
194    /// * `dim` - The dimension to select from.
195    /// * `indices` - The indices.
196    ///
197    /// # Returns
198    ///
199    /// The tensor with the selected elements.
200    fn int_select(tensor: IntTensor<B>, dim: usize, indices: IntTensor<B>) -> IntTensor<B>;
201
202    /// Assign the selected elements along the given dimension corresponding to the given indices
203    /// to the given value using sum reduction.
204    ///
205    /// # Arguments
206    ///
207    /// * `tensor` - The tensor.
208    /// * `dim` - The dimension to select from.
209    /// * `indices` - The indices.
210    /// * `value` - The value.
211    ///
212    /// # Returns
213    ///
214    /// The tensor with the selected elements assigned to the given value.
215    fn int_select_add(
216        tensor: IntTensor<B>,
217        dim: usize,
218        indices: IntTensor<B>,
219        value: IntTensor<B>,
220    ) -> IntTensor<B>;
221
222    /// Repeats the tensor along the given dimension the given number of times.
223    ///
224    /// # Arguments
225    ///
226    /// * `tensor` - The tensor.
227    /// * `dim` - The dimension to repeat.
228    /// * `times` - The number of times to repeat.
229    ///
230    /// # Returns
231    ///
232    /// The tensor with the given dimension repeated the given number of times.
233    fn int_repeat_dim(tensor: IntTensor<B>, dim: usize, times: usize) -> IntTensor<B> {
234        repeat_with_slice_assign::<B, Int>(tensor, dim, times)
235    }
236
237    /// Concatenates the given tensors along the given dimension.
238    ///
239    /// # Arguments
240    ///
241    /// * `tensors` - The tensors.
242    /// * `dim` - The dimension to concatenate along.
243    ///
244    /// # Returns
245    ///
246    /// The concatenated tensor.
247    ///
248    /// # Note
249    ///
250    /// Empty tensors (where the concatenation dimension has size 0) are filtered out at the
251    /// high-level tensor API and will not be passed to this method. Backend implementations do
252    /// not need to handle empty tensors.
253    fn int_cat(tensors: Vec<IntTensor<B>>, dim: usize) -> IntTensor<B> {
254        cat_with_slice_assign::<B, Int>(tensors, dim)
255    }
256
257    /// Element-wise equality comparison.
258    ///
259    /// # Arguments
260    ///
261    /// * `lhs` - The left-hand side tensor.
262    /// * `rhs` - The right-hand side tensor.
263    ///
264    /// # Returns
265    ///
266    /// The boolean tensor with the result of the comparison.
267    fn int_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;
268
269    /// Element-wise non-equality comparison.
270    ///
271    /// # Arguments
272    ///
273    /// * `lhs` - The left-hand side tensor.
274    /// * `rhs` - The right-hand side tensor.
275    ///
276    /// # Returns
277    ///
278    /// The boolean tensor with the result of the comparison.
279    fn int_not_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {
280        let equal_tensor = B::int_equal(lhs, rhs);
281        B::bool_not(equal_tensor)
282    }
283
284    /// Element-wise equality comparison with a scalar.
285    ///
286    /// # Arguments
287    ///
288    /// * `lhs` - The left-hand side tensor.
289    /// * `rhs` - The right-hand side scalar.
290    ///
291    /// # Returns
292    ///
293    /// The boolean tensor with the result of the comparison.
294    fn int_equal_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;
295
296    /// Element-wise non-equality comparison with a scalar.
297    ///
298    /// # Arguments
299    ///
300    /// * `lhs` - The left-hand side tensor.
301    /// * `rhs` - The right-hand side scalar.
302    ///
303    /// # Returns
304    ///
305    /// The boolean tensor with the result of the comparison.
306    fn int_not_equal_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B> {
307        let equal_tensor = B::int_equal_elem(lhs, rhs);
308        B::bool_not(equal_tensor)
309    }
310
311    /// Element-wise greater than comparison.
312    ///
313    /// # Arguments
314    ///
315    /// * `lhs` - The left-hand side tensor.
316    /// * `rhs` - The right-hand side tensor.
317    ///
318    /// # Returns
319    ///
320    /// The boolean tensor with the result of the comparison.
321    fn int_greater(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;
322
323    /// Element-wise greater than comparison with a scalar.
324    ///
325    /// # Arguments
326    ///
327    /// * `lhs` - The left-hand side tensor.
328    /// * `rhs` - The right-hand side scalar.
329    ///
330    /// # Returns
331    ///
332    /// The boolean tensor with the result of the comparison.
333    fn int_greater_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;
334
335    /// Element-wise greater than or equal comparison.
336    ///
337    /// # Arguments
338    ///
339    /// * `lhs` - The left-hand side tensor.
340    /// * `rhs` - The right-hand side tensor.
341    ///
342    /// # Returns
343    ///
344    /// The boolean tensor with the result of the comparison.
345    fn int_greater_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;
346
347    /// Element-wise greater than or equal comparison with a scalar.
348    ///
349    /// # Arguments
350    ///
351    /// * `lhs` - The left-hand side tensor.
352    /// * `rhs` - The right-hand side scalar.
353    ///
354    /// # Returns
355    ///
356    /// The boolean tensor with the result of the comparison.
357    fn int_greater_equal_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;
358
359    /// Element-wise less than comparison.
360    ///
361    /// # Arguments
362    ///
363    /// * `lhs` - The left-hand side tensor.
364    /// * `rhs` - The right-hand side tensor.
365    ///
366    /// # Returns
367    ///
368    /// The boolean tensor with the result of the comparison.
369    fn int_lower(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;
370
371    /// Element-wise less than comparison with a scalar.
372    ///
373    /// # Arguments
374    ///
375    /// * `lhs` - The left-hand side tensor.
376    /// * `rhs` - The right-hand side scalar.
377    ///
378    /// # Returns
379    ///
380    /// The boolean tensor with the result of the comparison.
381    fn int_lower_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;
382
383    /// Element-wise less than or equal comparison.
384    ///
385    /// # Arguments
386    ///
387    /// * `lhs` - The left-hand side tensor.
388    /// * `rhs` - The right-hand side tensor.
389    ///
390    /// # Returns
391    ///
392    /// The boolean tensor with the result of the comparison.
393    fn int_lower_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;
394
395    /// Element-wise less than or equal comparison with a scalar.
396    ///
397    /// # Arguments
398    ///
399    /// * `lhs` - The left-hand side tensor.
400    /// * `rhs` - The right-hand side scalar.
401    ///
402    /// # Returns
403    ///
404    /// The boolean tensor with the result of the comparison.
405    fn int_lower_equal_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;
406
407    // ====  NUMERIC ==== //
408
409    /// Element-wise addition.
410    ///
411    /// # Arguments
412    ///
413    /// * `lhs` - The left-hand side tensor.
414    /// * `rhs` - The right-hand side tensor.
415    ///
416    /// # Returns
417    ///
418    /// The result of the addition.
419    fn int_add(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
420
421    /// Element-wise addition with a scalar.
422    ///
423    /// # Arguments
424    ///
425    /// * `lhs` - The left-hand side tensor.
426    /// * `rhs` - The right-hand side scalar.
427    ///
428    /// # Returns
429    ///
430    /// The result of the addition.
431    fn int_add_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
432
433    /// Element-wise power with a IntTensor.
434    ///
435    /// # Arguments
436    ///
437    /// * `lhs` - The left-hand side IntTensor.
438    /// * `rhs` - The right-hand side IntTensor.
439    ///
440    /// # Returns
441    ///
442    /// The elements of `lhs` raised to the power of the elements of `rhs`.
443    fn int_powi(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
444        B::float_into_int(B::float_powi(B::int_into_float(lhs), rhs))
445    }
446
447    /// Element-wise power with a floatTensor.
448    ///
449    /// # Arguments
450    ///
451    /// * `lhs` - The left-hand side tensor.
452    /// * `rhs` - The right-hand side floatTensor.
453    ///
454    /// # Returns
455    ///
456    /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
457    fn int_powf(lhs: IntTensor<B>, rhs: FloatTensor<B>) -> IntTensor<B> {
458        B::float_into_int(B::float_powf(B::int_into_float(lhs), rhs))
459    }
460
461    /// Element-wise power with a scalar.
462    ///
463    /// # Backend Implementors Note
464    ///
465    /// A number of common exponent cases can be implemented with operations
466    /// which are much cheaper than generic exponentiation.
467    ///
468    /// This (`Backend` impl overridable) operation handles generic optimizations
469    /// for several common integer exponent cases; and then dispatches to
470    /// the (`Backend` impl overridable) [`Self::int_powi_scalar_impl`]
471    /// operation to handle the generic case.
472    ///
473    /// # Arguments
474    ///
475    /// * `lhs` - The left-hand side tensor.
476    /// * `rhs` - The right-hand side scalar.
477    ///
478    /// # Returns
479    ///
480    /// The elements of `lhs` raised to the value of `rhs`.
481    fn int_powi_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B> {
482        let exp = rhs.elem::<i32>();
483        match exp {
484            0 => Self::int_ones(lhs.shape(), &B::int_device(&lhs), lhs.dtype().into()),
485            1 => lhs,
486            2 => Self::int_mul(lhs.clone(), lhs),
487            _ => Self::int_powi_scalar_impl(lhs, rhs),
488        }
489    }
490
491    /// Element-wise power with a scalar.
492    ///
493    /// # Backend Implementors Note
494    ///
495    /// This is the generic implementation of integer exponentiation
496    /// called by [`Self::int_powi_scalar`] in the fallback case.
497    ///
498    /// By default, this performs a relatively expensive conversion to float,
499    /// exponentiation in float, and conversion back to int.
500    /// This reduces the minimal operation set for `Backend`s,
501    /// at the cost of performance.
502    ///
503    /// This is a good target for specialized optimizations in `Backend` implementations.
504    ///
505    /// As a general rule, this should not be called directly.
506    ///
507    /// # Arguments
508    ///
509    /// * `lhs` - The left-hand side tensor.
510    /// * `rhs` - The right-hand side scalar.
511    ///
512    /// # Returns
513    ///
514    /// The elements of `lhs` raised to the value of `rhs`.
515    fn int_powi_scalar_impl(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B> {
516        B::float_into_int(B::float_powi_scalar_impl(B::int_into_float(lhs), rhs))
517    }
518
519    /// Element-wise power with a floatTensor.
520    ///
521    /// Handles a number of special cases, then calls [`Self::int_powf_scalar_impl`].
522    ///
523    /// # Arguments
524    ///
525    /// * `lhs` - The left-hand side tensor.
526    /// * `rhs` - The right-hand side scalar.
527    ///
528    /// # Returns
529    ///
530    /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
531    fn int_powf_scalar(lhs: IntTensor<B>, rhs: f32) -> IntTensor<B> {
532        if num_traits::Float::floor(rhs) == rhs {
533            let exp = B::IntElem::from_elem(rhs as i32);
534            Self::int_powi_scalar(lhs, exp)
535        } else {
536            Self::int_powf_scalar_impl(lhs, rhs)
537        }
538    }
539
540    /// Element-wise power with a floatTensor.
541    ///
542    /// Fallback handler for [`Self::int_powf_scalar`].
543    ///
544    /// # Arguments
545    ///
546    /// * `lhs` - The left-hand side tensor.
547    /// * `rhs` - The right-hand side scalar.
548    ///
549    /// # Returns
550    ///
551    /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
552    fn int_powf_scalar_impl(lhs: IntTensor<B>, rhs: f32) -> IntTensor<B> {
553        B::float_into_int(B::float_powf_scalar_impl(B::int_into_float(lhs), rhs))
554    }
555
556    /// Clamps a tensor under a minimum value.
557    ///
558    /// # Arguments
559    ///
560    /// * `tensor` - The tensor to clamp.
561    /// * `min` - The minimum value.
562    ///
563    /// # Returns
564    ///
565    /// The clamped tensor.
566    fn int_clamp_min(tensor: IntTensor<B>, min: IntElem<B>) -> IntTensor<B> {
567        let mask = Self::int_lower_elem(tensor.clone(), min);
568        Self::int_mask_fill(tensor, mask, min)
569    }
570
571    /// Clamps a tensor over a maximum value.
572    ///
573    /// # Arguments
574    ///
575    /// * `tensor` - The tensor to clamp.
576    /// * `max` - The maximum value.
577    ///
578    /// # Returns
579    ///
580    /// The clamped tensor.
581    fn int_clamp_max(tensor: IntTensor<B>, max: IntElem<B>) -> IntTensor<B> {
582        let mask = Self::int_greater_elem(tensor.clone(), max);
583        Self::int_mask_fill(tensor, mask, max)
584    }
585
586    /// Clamps a tensor between a minimum and maximum value.
587    ///
588    /// # Arguments
589    ///
590    /// * `tensor` - The tensor to clamp.
591    /// * `min` - The minimum value.
592    /// * `max` - The maximum value.
593    ///
594    /// # Returns
595    ///
596    /// The clamped tensor.
597    fn int_clamp(tensor: IntTensor<B>, min: IntElem<B>, max: IntElem<B>) -> IntTensor<B> {
598        Self::int_clamp_min(Self::int_clamp_max(tensor, max), min)
599    }
600
601    /// Element-wise subtraction.
602    ///
603    /// # Arguments
604    ///
605    /// * `lhs` - The left-hand side tensor.
606    /// * `rhs` - The right-hand side tensor.
607    ///
608    /// # Returns
609    ///
610    /// The result of the subtraction.
611    fn int_sub(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
612
613    /// Element-wise subtraction with a scalar.
614    ///
615    /// # Arguments
616    ///
617    /// * `lhs` - The left-hand side tensor.
618    /// * `rhs` - The right-hand side scalar.
619    ///
620    /// # Returns
621    ///
622    /// The result of the subtraction.
623    fn int_sub_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
624
625    /// Element-wise multiplication.
626    ///
627    /// # Arguments
628    ///
629    /// * `lhs` - The left-hand side tensor.
630    /// * `rhs` - The right-hand side tensor.
631    ///
632    /// # Returns
633    ///
634    /// The result of the multiplication.
635    fn int_mul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
636
637    /// Element-wise multiplication with a scalar.
638    ///
639    /// # Arguments
640    ///
641    /// * `lhs` - The left-hand side tensor.
642    /// * `rhs` - The right-hand side scalar.
643    ///
644    /// # Returns
645    ///
646    /// The result of the multiplication.
647    fn int_mul_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
648
649    /// Element-wise division.
650    ///
651    /// # Arguments
652    ///
653    /// * `lhs` - The left-hand side tensor.
654    /// * `rhs` - The right-hand side tensor.
655    ///
656    /// # Returns
657    ///
658    /// The result of the division.
659    fn int_div(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
660
661    /// Element-wise division with a scalar.
662    ///
663    /// # Arguments
664    ///
665    /// * `lhs` - The left-hand side tensor.
666    /// * `rhs` - The right-hand side scalar.
667    ///
668    /// # Returns
669    ///
670    /// The result of the division.
671    fn int_div_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
672
673    /// Element-wise modulus.
674    ///
675    /// # Arguments
676    /// * `lhs` - The left-hand side tensor.
677    /// * `rhs` - The right-hand side scalar.
678    ///
679    /// # Returns
680    ///
681    /// The result of applying the modulus of the scalar to the tensor.
682    fn int_remainder(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
683
684    /// Element-wise modulus with a scalar.
685    ///
686    /// # Arguments
687    /// * `lhs` - The left-hand side tensor.
688    /// * `rhs` - The right-hand side scalar.
689    ///
690    /// # Returns
691    ///
692    /// The result of applying the modulus of the scalar to the tensor.
693    fn int_remainder_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
694
695    /// Multiplies two tensors together using matrix multiplication.
696    ///
697    /// # Arguments
698    ///
699    /// * `lhs` - The left-hand side tensor.
700    /// * `rhs` - The right-hand side tensor.
701    ///
702    /// # Returns
703    ///
704    /// The result of multiplying the two tensors together using matrix multiplication.
705    fn int_matmul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
706
707    /// Element-wise negation.
708    ///
709    /// # Arguments
710    ///
711    /// * `tensor` - The tensor to negate.
712    ///
713    /// # Returns
714    ///
715    /// The negated tensor.
716    fn int_neg(tensor: IntTensor<B>) -> IntTensor<B> {
717        Self::int_mul_scalar(tensor, (-1.0).elem::<IntElem<B>>())
718    }
719
720    /// Creates a tensor of zeros.
721    ///
722    /// # Arguments
723    ///
724    /// * `shape` - The shape of the tensor.
725    /// * `device` - The device to create the tensor on.
726    /// * `dtype` - The target data type.
727    ///
728    /// # Returns
729    ///
730    /// The tensor of zeros.
731    fn int_zeros(shape: Shape, device: &Device<B>, dtype: IntDType) -> IntTensor<B> {
732        Self::int_from_data(TensorData::full_dtype(shape, 0, dtype.into()), device)
733    }
734
735    /// Creates a tensor of ones.
736    ///
737    /// # Arguments
738    ///
739    /// * `shape` - The shape of the tensor.
740    /// * `device` - The device to create the tensor on.
741    /// * `dtype` - The target data type.
742    ///
743    /// # Returns
744    ///
745    /// The tensor of ones.
746    fn int_ones(shape: Shape, device: &Device<B>, dtype: IntDType) -> IntTensor<B> {
747        Self::int_from_data(TensorData::full_dtype(shape, 1, dtype.into()), device)
748    }
749
750    /// Creates a tensor filled with given value.
751    ///
752    /// # Arguments
753    ///
754    /// * `shape` - The shape of the tensor.
755    /// * `fill_value` - The value with which to fill the tensor.
756    /// * `device` - The device to create the tensor on.
757    /// * `dtype` - The target data type.
758    ///
759    /// # Returns
760    ///
761    /// The tensor filled with given value
762    fn int_full(
763        shape: Shape,
764        fill_value: IntElem<B>,
765        device: &Device<B>,
766        dtype: IntDType,
767    ) -> IntTensor<B> {
768        Self::int_from_data(
769            TensorData::full_dtype(shape, fill_value, dtype.into()),
770            device,
771        )
772    }
773
774    /// Sums all elements in the tensor.
775    ///
776    /// # Arguments
777    ///
778    /// * `tensor` - The tensor to sum.
779    ///
780    /// # Returns
781    ///
782    /// The sum of all elements in the tensor.
783    fn int_sum(tensor: IntTensor<B>) -> IntTensor<B>;
784
785    /// Sums all elements in the tensor along a dimension.
786    ///
787    /// # Arguments
788    ///
789    /// * `tensor` - The tensor to sum.
790    /// * `dim` - The dimension to sum along.
791    ///
792    /// # Returns
793    ///
794    /// The sum of all elements in the tensor along the dimension.
795    fn int_sum_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
796
797    /// Computes the product of all elements in the tensor.
798    ///
799    /// # Arguments
800    ///
801    /// * `tensor` - The tensor to compute the product of.
802    ///
803    /// # Returns
804    ///
805    /// The product of all elements in the tensor.
806    fn int_prod(tensor: IntTensor<B>) -> IntTensor<B>;
807
808    /// Computes the product of all elements in the tensor along a dimension.
809    ///
810    /// # Arguments
811    ///
812    /// * `tensor` - The tensor to compute the product of.
813    /// * `dim` - The dimension to compute the product along.
814    ///
815    /// # Returns
816    ///
817    /// The product of all elements in the tensor along the dimension.
818    fn int_prod_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
819
820    /// Computes the mean of all elements in the tensor.
821    ///
822    /// # Arguments
823    ///
824    /// * `tensor` - The tensor to compute the mean of.
825    ///
826    /// # Returns
827    ///
828    /// The mean of all elements in the tensor.
829    fn int_mean(tensor: IntTensor<B>) -> IntTensor<B> {
830        let num_elems = tensor.shape().num_elements();
831        B::int_div_scalar(B::int_sum(tensor), (num_elems as i64).elem())
832    }
833
834    /// Computes the mean of all elements in the tensor along a dimension.
835    ///
836    /// # Arguments
837    ///
838    /// * `tensor` - The tensor to compute the mean of.
839    ///
840    /// # Returns
841    ///
842    /// The mean of all elements in the tensor along the dimension.
843    fn int_mean_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
844
845    /// Computes the cumulative sum of elements along a dimension.
846    ///
847    /// # Arguments
848    ///
849    /// * `tensor` - The tensor to compute the cumulative sum of.
850    /// * `dim` - The dimension along which to compute the cumulative sum.
851    ///
852    /// # Returns
853    ///
854    /// A tensor with the same shape where each element is the cumulative sum
855    /// of all elements up to and including that position along the dimension.
856    fn int_cumsum(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
857
858    /// Computes the cumulative product of elements along a dimension.
859    ///
860    /// # Arguments
861    ///
862    /// * `tensor` - The tensor to compute the cumulative product of.
863    /// * `dim` - The dimension along which to compute the cumulative product.
864    ///
865    /// # Returns
866    ///
867    /// A tensor with the same shape where each element is the cumulative product
868    /// of all elements up to and including that position along the dimension.
869    fn int_cumprod(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
870
871    /// Computes the cumulative minimum of elements along a dimension.
872    ///
873    /// # Arguments
874    ///
875    /// * `tensor` - The tensor to compute the cumulative minimum of.
876    /// * `dim` - The dimension along which to compute the cumulative minimum.
877    ///
878    /// # Returns
879    ///
880    /// A tensor with the same shape where each element is the minimum
881    /// of all elements up to and including that position along the dimension.
882    fn int_cummin(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
883
884    /// Computes the cumulative maximum of elements along a dimension.
885    ///
886    /// # Arguments
887    ///
888    /// * `tensor` - The tensor to compute the cumulative maximum of.
889    /// * `dim` - The dimension along which to compute the cumulative maximum.
890    ///
891    /// # Returns
892    ///
893    /// A tensor with the same shape where each element is the maximum
894    /// of all elements up to and including that position along the dimension.
895    fn int_cummax(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
896
897    /// Gets the indices of the maximum elements along a dimension.
898    ///
899    /// # Arguments
900    ///
901    /// * `tensor` - The tensor to get the maximum indices of.
902    /// * `dim` - The dimension to get the maximum indices along.
903    ///
904    /// # Returns
905    ///
906    /// The indices of the maximum elements along the dimension.
907    fn int_argmax(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
908
909    /// Gets the indices of the minimum elements along a dimension.
910    ///
911    /// # Arguments
912    ///
913    /// * `tensor` - The tensor to get the minimum indices of.
914    /// * `dim` - The dimension to get the minimum indices along.
915    ///
916    /// # Returns
917    ///
918    /// The indices of the minimum elements along the dimension.
919    fn int_argmin(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
920
921    /// Gets the maximum element in the tensor.
922    ///
923    /// # Arguments
924    ///
925    /// * `tensor` - The tensor to get the maximum element of.
926    ///
927    /// # Returns
928    ///
929    /// The maximum element in the tensor.
930    fn int_max(tensor: IntTensor<B>) -> IntTensor<B> {
931        let shape = tensor.shape();
932        let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()]));
933
934        B::int_max_dim(tensor, 0)
935    }
936
937    /// Gets the maximum element in the tensor along a dimension.
938    ///
939    /// # Arguments
940    ///
941    /// * `tensor` - The tensor to get the maximum element of.
942    /// * `dim` - The dimension to get the maximum element along.
943    ///
944    /// # Returns
945    ///
946    /// The maximum element in the tensor along the dimension.
947    fn int_max_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
948        let index = B::int_argmax(tensor.clone(), dim);
949        B::int_gather(dim, tensor, index)
950    }
951
952    /// Gets the maximum elements and corresponding indices along a dimension.
953    ///
954    /// # Arguments
955    ///
956    /// * `tensor` - The tensor to get the maximum elements and indices of.
957    /// * `dim` - The dimension to get the maximum elements and indices along.
958    ///
959    /// # Returns
960    ///
961    /// The maximum elements and corresponding indices along the dimension.
962    fn int_max_dim_with_indices(tensor: IntTensor<B>, dim: usize) -> (IntTensor<B>, IntTensor<B>) {
963        let index = B::int_argmax(tensor.clone(), dim);
964        let values = B::int_gather(dim, tensor, index.clone());
965
966        (values, index)
967    }
968
969    /// Gets the maximum absolute element in the tensor.
970    ///
971    /// # Arguments
972    ///
973    /// * `tensor` - The tensor to get the maximum element of.
974    ///
975    /// # Returns
976    ///
977    /// The maximum element in the tensor.
978    fn int_max_abs(tensor: IntTensor<B>) -> IntTensor<B> {
979        let shape = tensor.shape();
980        let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()]));
981
982        B::int_max_abs_dim(tensor, 0)
983    }
984
985    /// Gets the maximum absolute element in the tensor along a dimension.
986    ///
987    /// # Arguments
988    ///
989    /// * `tensor` - The tensor to get the maximum element of.
990    /// * `dim` - The dimension to get the maximum element along.
991    ///
992    /// # Returns
993    ///
994    /// The maximum element in the tensor along the dimension.
995    fn int_max_abs_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
996        B::int_max_dim(B::int_abs(tensor), dim)
997    }
998
999    /// Gets the minimum element in the tensor.
1000    ///
1001    /// # Arguments
1002    ///
1003    /// * `tensor` - The tensor to get the minimum element of.
1004    ///
1005    /// # Returns
1006    ///
1007    /// The minimum element in the tensor.
1008    fn int_min(tensor: IntTensor<B>) -> IntTensor<B> {
1009        let shape = tensor.shape();
1010        let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()]));
1011
1012        B::int_min_dim(tensor, 0)
1013    }
1014
1015    /// Gets the minimum elements in the tensor along a dimension.
1016    ///
1017    /// # Arguments
1018    ///
1019    /// * `tensor` - The tensor to get the minimum element of.
1020    /// * `dim` - The dimension to get the minimum element along.
1021    ///
1022    /// # Returns
1023    ///
1024    /// The minimum element in the tensor along the dimension.
1025    fn int_min_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
1026        let index = B::int_argmin(tensor.clone(), dim);
1027        B::int_gather(dim, tensor, index)
1028    }
1029
1030    /// Gets the minimum elements and corresponding indices along a dimension.
1031    ///
1032    /// # Arguments
1033    ///
1034    /// * `tensor` - The tensor to get the minimum elements and indices of.
1035    /// * `dim` - The dimension to get the minimum elements and indices along.
1036    ///
1037    /// # Returns
1038    ///
1039    /// The minimum elements and corresponding indices along the dimension.
1040    fn int_min_dim_with_indices(tensor: IntTensor<B>, dim: usize) -> (IntTensor<B>, IntTensor<B>) {
1041        let indices = B::int_argmin(tensor.clone(), dim);
1042        let values = B::int_gather(dim, tensor, indices.clone());
1043
1044        (values, indices)
1045    }
1046
    /// Returns a new tensor with absolute values.
    ///
    /// The default `int_max_abs_dim` in this trait relies on this method.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take absolute value of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with absolute values.
    ///
    /// NOTE(review): the most negative value of a signed integer type has no positive
    /// counterpart; how it is handled (wrap, saturate, panic) is backend-defined — confirm.
    fn int_abs(tensor: IntTensor<B>) -> IntTensor<B>;
1057
1058    /// Transposes an int tensor.
1059    ///
1060    /// # Arguments
1061    ///
1062    /// * `tensor` - The tensor to transpose.
1063    ///
1064    /// # Returns
1065    ///
1066    /// The transposed tensor.
1067    fn int_transpose(tensor: IntTensor<B>) -> IntTensor<B> {
1068        let ndims = tensor.shape().num_dims();
1069        Self::int_swap_dims(tensor, ndims - 2, ndims - 1)
1070    }
1071
    /// Swaps two dimensions of an int tensor.
    ///
    /// The default `int_transpose` in this trait calls this with the last two dimensions.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to swap the dimensions of.
    /// * `dim1` - The first dimension to swap.
    /// * `dim2` - The second dimension to swap.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions swapped.
    fn int_swap_dims(tensor: IntTensor<B>, dim1: usize, dim2: usize) -> IntTensor<B>;
1084
    /// Permutes the dimensions of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to permute the dimensions of.
    /// * `axes` - The new order of the dimensions (presumably one entry per dimension of
    ///   `tensor`, each used exactly once — confirm against backend implementations).
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions permuted.
    fn int_permute(tensor: IntTensor<B>, axes: &[usize]) -> IntTensor<B>;
1095
    /// Reverses the order of elements in a tensor along the given axes.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to reverse.
    /// * `axes` - The axes to reverse.
    ///
    /// # Returns
    ///
    /// The tensor with the elements reversed.
    fn int_flip(tensor: IntTensor<B>, axes: &[usize]) -> IntTensor<B>;
1105
    /// Creates a new int tensor with random values.
    ///
    /// # Arguments
    ///
    /// * `shape` - The shape of the tensor.
    /// * `distribution` - The distribution to sample from.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// The tensor with the given shape and random values.
    ///
    /// NOTE(review): how samples from `distribution` are converted to integers is left to
    /// the backend — confirm before relying on a specific rounding behavior.
    fn int_random(shape: Shape, distribution: Distribution, device: &Device<B>) -> IntTensor<B>;
1117
1118    /// Creates a new tensor with values from the given range with the given step size.
1119    ///
1120    /// # Arguments
1121    ///
1122    /// * `range` - The range of values.
1123    /// * `step` - The step size.
1124    /// * `device` - The device to create the tensor on.
1125    ///
1126    /// # Returns
1127    ///
1128    /// The tensor with the given values.
1129    fn int_arange_step(range: Range<i64>, step: usize, device: &Device<B>) -> IntTensor<B> {
1130        let value = range
1131            .step_by(step)
1132            .map(|i| i.elem())
1133            .collect::<Vec<IntElem<B>>>();
1134        let shape = Shape::new([value.len()]);
1135        let data = TensorData::new(value, shape);
1136        B::int_from_data(data, device)
1137    }
1138
1139    /// Creates a new tensor with values from the given range.
1140    ///
1141    /// # Arguments
1142    ///
1143    /// * `range` - The range of values.
1144    /// * `device` - The device to create the tensor on.
1145    ///
1146    /// # Returns
1147    ///
1148    /// The tensor with the given values.
1149    ///
1150    /// # Remarks
1151    ///
1152    /// Uses `arange_step` with a step size of 1 under the hood.
1153    fn int_arange(range: Range<i64>, device: &Device<B>) -> IntTensor<B> {
1154        Self::int_arange_step(range, 1, device)
1155    }
1156
1157    /// Tests if any element in the int `tensor` evaluates to True.
1158    ///
1159    /// # Arguments
1160    ///
1161    /// * `tensor` - The tensor to test.
1162    ///
1163    /// # Returns
1164    ///
1165    /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
1166    fn int_any(tensor: IntTensor<B>) -> BoolTensor<B> {
1167        let bool_tensor = B::int_equal_elem(tensor, 0.elem());
1168        let bool_tensor = B::bool_not(bool_tensor);
1169        let sum = B::int_sum(B::bool_into_int(bool_tensor));
1170        B::int_greater_elem(sum, 0.elem())
1171    }
1172
1173    /// Tests if any element in the int `tensor` evaluates to True along a given dimension `dim`.
1174    ///
1175    /// # Arguments
1176    ///
1177    /// * `tensor` - The tensor to test.
1178    /// * `dim` - The axis along which to test.
1179    ///
1180    /// # Returns
1181    ///
1182    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1183    /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input
1184    /// evaluates to True, False otherwise.
1185    fn int_any_dim(tensor: IntTensor<B>, dim: usize) -> BoolTensor<B> {
1186        let bool_tensor = B::int_equal_elem(tensor, 0.elem());
1187        let bool_tensor = B::bool_not(bool_tensor);
1188        let sum = B::int_sum_dim(B::bool_into_int(bool_tensor), dim);
1189        B::int_greater_elem(sum, 0.elem())
1190    }
1191
1192    /// Tests if all elements in the int `tensor` evaluate to True.
1193    ///
1194    /// # Arguments
1195    ///
1196    /// * `tensor` - The tensor to test.
1197    ///
1198    /// # Returns
1199    ///
1200    /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1201    /// evaluate to True, False otherwise.
1202    fn int_all(tensor: IntTensor<B>) -> BoolTensor<B> {
1203        let num_elems = tensor.shape().num_elements();
1204        let bool_tensor = B::int_equal_elem(tensor, 0.elem());
1205        let bool_tensor = B::bool_not(bool_tensor);
1206        let sum = B::int_sum(B::bool_into_int(bool_tensor));
1207        B::int_equal_elem(sum, (num_elems as i32).elem())
1208    }
1209
1210    /// Tests if all elements in the int `tensor` evaluate to True along a given dimension `dim`.
1211    ///
1212    /// # Arguments
1213    ///
1214    /// * `tensor` - The tensor to test.
1215    /// * `dim` - The axis along which to test.
1216    ///
1217    /// # Returns
1218    ///
1219    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1220    /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1221    /// evaluates to True, False otherwise.
1222    fn int_all_dim(tensor: IntTensor<B>, dim: usize) -> BoolTensor<B> {
1223        let num_elems = tensor.shape().dims[dim];
1224        let bool_tensor = B::int_equal_elem(tensor, 0.elem());
1225        let bool_tensor = B::bool_not(bool_tensor);
1226        let sum = B::int_sum_dim(B::bool_into_int(bool_tensor), dim);
1227        B::int_equal_elem(sum, (num_elems as i32).elem())
1228    }
1229
1230    /// Returns the signs of the int `tensor`.
1231    ///
1232    /// # Arguments
1233    ///
1234    /// * `tensor` - The tensor to extract the signs from.
1235    ///
1236    /// # Returns
1237    ///
1238    /// A tensor with the same shape as `tensor` containing the signs of the elements of `tensor`.
1239    fn int_sign(tensor: IntTensor<B>) -> IntTensor<B> {
1240        let dtype = tensor.dtype();
1241        let zeros = B::int_zeros(tensor.shape(), &B::int_device(&tensor), dtype.into());
1242        let less_than_zero = B::int_lower_elem(tensor.clone(), 0.0f32.elem());
1243        let greater_than_zero = B::int_greater_elem(tensor, 0.0f32.elem());
1244
1245        let mut result = B::int_mask_fill(zeros, less_than_zero, (-1.0f32).elem());
1246        result = B::int_mask_fill(result, greater_than_zero, 1.0f32.elem());
1247        result
1248    }
1249
    /// Broadcasts the int `tensor` to the given `shape`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to broadcast.
    /// * `shape` - The target shape.
    ///
    /// # Returns
    ///
    /// A tensor with the given `shape`.
    ///
    /// NOTE(review): which dimensions may be expanded (and whether the rank may grow) is
    /// not specified here — confirm against backend implementations.
    fn int_expand(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B>;
1252
1253    /// Sort the elements of the input `tensor` by value along a given dimension.
1254    ///
1255    /// This sort is unstable (i.e., may reorder equal elements).
1256    ///
1257    /// # Arguments
1258    ///
1259    /// * `tensor` - The input tensor.
1260    /// * `dim` - The axis along which to sort.
1261    /// * `descending` - The sorting order.
1262    ///
1263    /// # Returns
1264    ///
1265    /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
1266    fn int_sort(tensor: IntTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
1267        sort::<B, Int>(tensor, dim, descending)
1268    }
1269
1270    /// Sort the elements of the input `tensor` by value along a given dimension.
1271    ///
1272    /// This sort is unstable (i.e., may reorder equal elements).
1273    ///
1274    /// # Arguments
1275    ///
1276    /// * `tensor` - The input tensor.
1277    /// * `dim` - The axis along which to sort.
1278    ///
1279    /// # Returns
1280    ///
1281    /// A tensor with the same shape as the input tensor and corresponding indices, where
1282    /// the elements are sorted by value and the indices map back to the original input tensor.
1283    fn int_sort_with_indices(
1284        tensor: IntTensor<B>,
1285        dim: usize,
1286        descending: bool,
1287    ) -> (IntTensor<B>, IntTensor<B>) {
1288        sort_with_indices::<B, Int>(tensor, dim, descending)
1289    }
1290
1291    /// Returns the indices that sort the elements of the input `tensor` by value
1292    /// along a given dimension.
1293    ///
1294    /// This sort is unstable (i.e., may reorder equal elements).
1295    ///
1296    /// # Arguments
1297    ///
1298    /// * `tensor` - The input tensor.
1299    /// * `dim` - The axis along which to sort.
1300    /// * `descending` - The sorting order.
1301    ///
1302    /// # Returns
1303    ///
1304    /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.
1305    fn int_argsort(tensor: IntTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
1306        argsort::<B, Int>(tensor, dim, descending)
1307    }
1308
    // NOTE(review): for the tensor-tensor variants below, whether `lhs` and `rhs` may differ
    // in shape (broadcasting) is not specified in this trait — confirm with the backends.

    /// Bitwise AND operation for Int Tensors: `lhs & rhs`, element-wise.
    fn bitwise_and(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

    /// Bitwise AND operation for Int Tensors with a scalar: each element of `lhs` is ANDed
    /// with `rhs`.
    fn bitwise_and_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;

    /// Bitwise OR operation for Int Tensors: `lhs | rhs`, element-wise.
    fn bitwise_or(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

    /// Bitwise OR operation for Int Tensors with a scalar: each element of `lhs` is ORed
    /// with `rhs`.
    fn bitwise_or_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;

    /// Bitwise XOR operation for Int Tensors: `lhs ^ rhs`, element-wise.
    fn bitwise_xor(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

    /// Bitwise XOR operation for Int Tensors with a scalar: each element of `lhs` is XORed
    /// with `rhs`.
    fn bitwise_xor_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;

    /// Bitwise NOT operation for Int Tensors: every bit of every element is inverted.
    fn bitwise_not(tensor: IntTensor<B>) -> IntTensor<B>;

    /// Bitwise left shift operation for Int Tensors: `lhs << rhs`, element-wise.
    ///
    /// NOTE(review): behavior for negative or too-large shift amounts is backend-defined
    /// here — confirm.
    fn bitwise_left_shift(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

    /// Bitwise left shift operation for Int Tensors with a scalar shift amount `rhs`.
    fn bitwise_left_shift_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;

    /// Bitwise right shift operation for Int Tensors: `lhs >> rhs`, element-wise.
    ///
    /// NOTE(review): whether the shift is arithmetic or logical for signed types, and the
    /// behavior for out-of-range shift amounts, are backend-defined here — confirm.
    fn bitwise_right_shift(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

    /// Bitwise right shift operation for Int Tensors with a scalar shift amount `rhs`.
    fn bitwise_right_shift_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
1341
    /// Converts a tensor to another integer data type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to convert.
    /// * `dtype` - The target data type.
    ///
    /// # Returns
    ///
    /// A tensor with the same values as `tensor` but in the target integer data type.
    ///
    /// NOTE(review): handling of values that do not fit in `dtype` (wrap, saturate, ...)
    /// is backend-defined here — confirm before relying on it.
    fn int_cast(tensor: IntTensor<B>, dtype: IntDType) -> IntTensor<B>;
1353
    /// Unfold windows along a dimension.
    ///
    /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`,
    /// where windows are advanced by `step` at each index.
    ///
    /// The number of windows is `(shape[dim] - size) / step + 1` (integer division) when
    /// `shape[dim] >= size`, and `0` otherwise. For example, `shape[dim] = 5`, `size = 2`,
    /// `step = 2` yields the two complete windows starting at indices 0 and 2.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
    /// * `dim` - the selected dim.
    /// * `size` - the size of each unfolded window.
    /// * `step` - the step between each window (presumably must be non-zero — confirm).
    ///
    /// # Returns
    ///
    /// A tensor view with shape ``[pre=..., windows, size, post=...]``.
    fn int_unfold(tensor: IntTensor<B>, dim: usize, size: usize, step: usize) -> IntTensor<B>;
1372}