Skip to main content

burn_backend/backend/ops/
int_tensor.rs

1use super::cat::cat_with_slice_assign;
2use super::repeat_dim::repeat_with_slice_assign;
3use super::sort::{argsort, sort, sort_with_indices};
4use crate::tensor::{BoolTensor, Device, FloatTensor, Int, IntElem, IntTensor};
5use crate::{Backend, Distribution, TensorData, TensorMetadata, element::ElementConversion};
6use crate::{ExecutionError, Scalar, get_device_settings};
7use alloc::vec::Vec;
8use burn_std::{BoolDType, FloatDType, IntDType, Shape, Slice};
9use core::ops::Range;
10
11/// Int Tensor API for basic and numeric operations, see
12#[cfg_attr(doc, doc = crate::doc_tensor!())]
13#[cfg_attr(not(doc), doc = "`Tensor`")]
14/// for documentation on each function.
15pub trait IntTensorOps<B: Backend> {
    /// Creates a new int tensor of the given shape.
    ///
    /// # Arguments
    ///
    /// * `shape` - The shape of the tensor.
    /// * `device` - The device to create the tensor on.
    /// * `dtype` - The target integer data type.
    ///
    /// # Returns
    ///
    /// The integer tensor with the given shape.
    ///
    /// # Note
    ///
    /// This trait does not state what the elements contain; presumably they are left
    /// uninitialized (TODO confirm per backend). Use [`Self::int_zeros`], [`Self::int_ones`] or
    /// [`Self::int_full`] when specific values are required.
    fn int_empty(shape: Shape, device: &Device<B>, dtype: IntDType) -> IntTensor<B>;
28
    /// Converts the tensor to a data structure.
    ///
    /// The tensor is consumed and the result is exposed as a future, so the caller has to
    /// await (or otherwise drive) it to obtain the data.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to read.
    ///
    /// # Returns
    ///
    /// A `Send` future resolving to the [`TensorData`] with the tensor's data.
    ///
    /// # Errors
    ///
    /// The future resolves to an [`ExecutionError`] when the data could not be produced.
    fn int_into_data(
        tensor: IntTensor<B>,
    ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;
41
    /// Creates a tensor from the data structure.
    ///
    /// The default implementations of [`Self::int_zeros`], [`Self::int_ones`] and
    /// [`Self::int_full`] build their result through this method.
    ///
    /// # Arguments
    ///
    /// * `data` - The data structure holding the values of the tensor.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// The tensor with the data.
    fn int_from_data(data: TensorData, device: &Device<B>) -> IntTensor<B>;
53
    /// Gets the device of the tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to inspect; it is only borrowed.
    ///
    /// # Returns
    ///
    /// The device of the tensor.
    fn int_device(tensor: &IntTensor<B>) -> Device<B>;
64
    /// Moves the tensor to the given device.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to move.
    /// * `device` - The device to move the tensor to.
    ///
    /// # Returns
    ///
    /// The tensor on the given device.
    fn int_to_device(tensor: IntTensor<B>, device: &Device<B>) -> IntTensor<B>;
67
    /// Reshapes the tensor.
    ///
    /// The default reductions over the whole tensor (for example [`Self::int_max`]) use this
    /// method to flatten their input to a single dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to reshape.
    /// * `shape` - The new shape.
    ///
    /// # Returns
    ///
    /// The tensor with the new shape.
    fn int_reshape(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B>;
79
    /// Gets the elements of the tensor selected by the given slices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to slice.
    /// * `slices` - The slices specifying ranges and steps for each dimension.
    ///
    /// # Returns
    ///
    /// The tensor holding the elements selected by the slices.
    ///
    /// # Note
    ///
    /// Empty slices (where start >= end) are handled at the high-level tensor API and will not
    /// be passed to this method. Backend implementations do not need to handle empty slices.
    fn int_slice(tensor: IntTensor<B>, slices: &[Slice]) -> IntTensor<B>;
96
    /// Sets the values in the tensor for the given slices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to assign into.
    /// * `slices` - The slices selecting the region to set the values for.
    /// * `value` - The tensor holding the values to assign.
    ///
    /// # Returns
    ///
    /// The tensor with the values set for the given slices.
    ///
    /// # Note
    ///
    /// Empty slice assignments (where any slice range produces 0 elements) are handled at the
    /// high-level tensor API and will not be passed to this method. Backend implementations do
    /// not need to handle empty slice assignments.
    fn int_slice_assign(
        tensor: IntTensor<B>,
        slices: &[Slice],
        value: IntTensor<B>,
    ) -> IntTensor<B>;
118
    /// Converts an int tensor to a float tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The int tensor to convert.
    /// * `out_dtype` - The dtype of the returned float tensor.
    ///
    /// # Returns
    ///
    /// The float tensor with the same data as the int tensor.
    fn int_into_float(tensor: IntTensor<B>, out_dtype: FloatDType) -> FloatTensor<B>;
130
    /// Fills the tensor with values from the value tensor wherever the mask is true.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to update.
    /// * `mask` - The boolean mask selecting the elements to replace.
    /// * `value` - The tensor providing the replacement values.
    ///
    /// # Returns
    ///
    /// The tensor with the masked elements taken from `value`.
    fn int_mask_where(
        tensor: IntTensor<B>,
        mask: BoolTensor<B>,
        value: IntTensor<B>,
    ) -> IntTensor<B>;
148
    /// Fills the tensor with the given value wherever the mask is true.
    ///
    /// The default implementations of [`Self::int_clamp_min`] and [`Self::int_clamp_max`]
    /// are built on this method.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to update.
    /// * `mask` - The boolean mask selecting the elements to replace.
    /// * `value` - The scalar written at every masked position.
    ///
    /// # Returns
    ///
    /// The tensor with the masked elements set to `value`.
    fn int_mask_fill(tensor: IntTensor<B>, mask: BoolTensor<B>, value: Scalar) -> IntTensor<B>;
161
    /// Gathers elements from the tensor at the given indices along a dimension.
    ///
    /// The default implementation of [`Self::int_max_dim`] uses this method to read the
    /// values located by [`Self::int_argmax`].
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension to gather from.
    /// * `tensor` - The tensor to gather from.
    /// * `indices` - The indices of the elements to gather.
    ///
    /// # Returns
    ///
    /// The tensor holding the gathered elements.
    fn int_gather(dim: usize, tensor: IntTensor<B>, indices: IntTensor<B>) -> IntTensor<B>;
170
    /// Scatters the given values into the tensor at the given indices using sum reduction.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension to scatter to.
    /// * `tensor` - The tensor to scatter into.
    /// * `indices` - The indices at which the values are added.
    /// * `value` - The tensor holding the values to scatter.
    ///
    /// # Returns
    ///
    /// The tensor with the values scattered.
    fn int_scatter_add(
        dim: usize,
        tensor: IntTensor<B>,
        indices: IntTensor<B>,
        value: IntTensor<B>,
    ) -> IntTensor<B>;
189
    /// Selects tensor elements along the given dimension corresponding to the given indices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select from.
    /// * `dim` - The dimension to select from.
    /// * `indices` - The indices of the elements to select.
    ///
    /// # Returns
    ///
    /// The tensor with the selected elements.
    fn int_select(tensor: IntTensor<B>, dim: usize, indices: IntTensor<B>) -> IntTensor<B>;
202
    /// Adds the given values to the elements selected along the given dimension by the given
    /// indices (assignment with sum reduction).
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to update.
    /// * `dim` - The dimension to select from.
    /// * `indices` - The indices of the elements to update.
    /// * `value` - The tensor holding the values to add.
    ///
    /// # Returns
    ///
    /// The tensor with the given values summed into the selected elements.
    fn int_select_add(
        tensor: IntTensor<B>,
        dim: usize,
        indices: IntTensor<B>,
        value: IntTensor<B>,
    ) -> IntTensor<B>;
222
    /// Repeats the tensor along the given dimension the given number of times.
    ///
    /// The default implementation delegates to the shared `repeat_with_slice_assign` helper,
    /// instantiated for the `Int` tensor kind. Backends with a native repeat operation can
    /// override it.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to repeat.
    /// * `dim` - The dimension to repeat.
    /// * `times` - The number of times to repeat.
    ///
    /// # Returns
    ///
    /// The tensor with the given dimension repeated the given number of times.
    fn int_repeat_dim(tensor: IntTensor<B>, dim: usize, times: usize) -> IntTensor<B> {
        repeat_with_slice_assign::<B, Int>(tensor, dim, times)
    }
237
    /// Concatenates the given tensors along the given dimension.
    ///
    /// The default implementation delegates to the shared `cat_with_slice_assign` helper,
    /// instantiated for the `Int` tensor kind. Backends with a native concatenation can
    /// override it.
    ///
    /// # Arguments
    ///
    /// * `tensors` - The tensors to concatenate.
    /// * `dim` - The dimension to concatenate along.
    ///
    /// # Returns
    ///
    /// The concatenated tensor.
    ///
    /// # Note
    ///
    /// Empty tensors (where the concatenation dimension has size 0) are filtered out at the
    /// high-level tensor API and will not be passed to this method. Backend implementations do
    /// not need to handle empty tensors.
    fn int_cat(tensors: Vec<IntTensor<B>>, dim: usize) -> IntTensor<B> {
        cat_with_slice_assign::<B, Int>(tensors, dim)
    }
257
    /// Element-wise equality comparison (`lhs == rhs`).
    ///
    /// The default implementation of [`Self::int_not_equal`] negates the result of this method.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    /// * `out_dtype` - The dtype of the returned boolean tensor.
    ///
    /// # Returns
    ///
    /// The boolean tensor with the result of the comparison.
    fn int_equal(lhs: IntTensor<B>, rhs: IntTensor<B>, out_dtype: BoolDType) -> BoolTensor<B>;
270
271    /// Element-wise non-equality comparison.
272    ///
273    /// # Arguments
274    ///
275    /// * `lhs` - The left-hand side tensor.
276    /// * `rhs` - The right-hand side tensor.
277    /// * `out_dtype` - The output tensor dtype.
278    ///
279    /// # Returns
280    ///
281    /// The boolean tensor with the result of the comparison.
282    fn int_not_equal(lhs: IntTensor<B>, rhs: IntTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
283        let equal_tensor = B::int_equal(lhs, rhs, out_dtype);
284        B::bool_not(equal_tensor)
285    }
286
    /// Element-wise equality comparison with a scalar (`lhs == rhs`).
    ///
    /// The default implementation of [`Self::int_not_equal_elem`] negates the result of this
    /// method.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar, compared against every element.
    /// * `out_dtype` - The dtype of the returned boolean tensor.
    ///
    /// # Returns
    ///
    /// The boolean tensor with the result of the comparison.
    fn int_equal_elem(lhs: IntTensor<B>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<B>;
299
300    /// Element-wise non-equality comparison with a scalar.
301    ///
302    /// # Arguments
303    ///
304    /// * `lhs` - The left-hand side tensor.
305    /// * `rhs` - The right-hand side scalar.
306    /// * `out_dtype` - The output tensor dtype.
307    ///
308    /// # Returns
309    ///
310    /// The boolean tensor with the result of the comparison.
311    fn int_not_equal_elem(lhs: IntTensor<B>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<B> {
312        let equal_tensor = B::int_equal_elem(lhs, rhs, out_dtype);
313        B::bool_not(equal_tensor)
314    }
315
    /// Element-wise greater than comparison (`lhs > rhs`).
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    /// * `out_dtype` - The dtype of the returned boolean tensor.
    ///
    /// # Returns
    ///
    /// The boolean tensor with the result of the comparison.
    fn int_greater(lhs: IntTensor<B>, rhs: IntTensor<B>, out_dtype: BoolDType) -> BoolTensor<B>;
328
    /// Element-wise greater than comparison with a scalar (`lhs > rhs`).
    ///
    /// The default implementation of [`Self::int_clamp_max`] uses this method to find the
    /// elements that exceed the maximum.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar, compared against every element.
    /// * `out_dtype` - The dtype of the returned boolean tensor.
    ///
    /// # Returns
    ///
    /// The boolean tensor with the result of the comparison.
    fn int_greater_elem(lhs: IntTensor<B>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<B>;
341
    /// Element-wise greater than or equal comparison (`lhs >= rhs`).
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    /// * `out_dtype` - The dtype of the returned boolean tensor.
    ///
    /// # Returns
    ///
    /// The boolean tensor with the result of the comparison.
    fn int_greater_equal(
        lhs: IntTensor<B>,
        rhs: IntTensor<B>,
        out_dtype: BoolDType,
    ) -> BoolTensor<B>;
358
    /// Element-wise greater than or equal comparison with a scalar (`lhs >= rhs`).
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar, compared against every element.
    /// * `out_dtype` - The dtype of the returned boolean tensor.
    ///
    /// # Returns
    ///
    /// The boolean tensor with the result of the comparison.
    fn int_greater_equal_elem(
        lhs: IntTensor<B>,
        rhs: Scalar,
        out_dtype: BoolDType,
    ) -> BoolTensor<B>;
375
    /// Element-wise less than comparison (`lhs < rhs`).
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    /// * `out_dtype` - The dtype of the returned boolean tensor.
    ///
    /// # Returns
    ///
    /// The boolean tensor with the result of the comparison.
    fn int_lower(lhs: IntTensor<B>, rhs: IntTensor<B>, out_dtype: BoolDType) -> BoolTensor<B>;
388
    /// Element-wise less than comparison with a scalar (`lhs < rhs`).
    ///
    /// The default implementation of [`Self::int_clamp_min`] uses this method to find the
    /// elements that fall below the minimum.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar, compared against every element.
    /// * `out_dtype` - The dtype of the returned boolean tensor.
    ///
    /// # Returns
    ///
    /// The boolean tensor with the result of the comparison.
    fn int_lower_elem(lhs: IntTensor<B>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<B>;
401
    /// Element-wise less than or equal comparison (`lhs <= rhs`).
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    /// * `out_dtype` - The dtype of the returned boolean tensor.
    ///
    /// # Returns
    ///
    /// The boolean tensor with the result of the comparison.
    fn int_lower_equal(lhs: IntTensor<B>, rhs: IntTensor<B>, out_dtype: BoolDType)
    -> BoolTensor<B>;
415
    /// Element-wise less than or equal comparison with a scalar (`lhs <= rhs`).
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar, compared against every element.
    /// * `out_dtype` - The dtype of the returned boolean tensor.
    ///
    /// # Returns
    ///
    /// The boolean tensor with the result of the comparison.
    fn int_lower_equal_elem(lhs: IntTensor<B>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<B>;
428
429    // ====  NUMERIC ==== //
430
    /// Element-wise addition (`lhs + rhs`).
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// The tensor holding the result of the addition.
    fn int_add(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
442
    /// Element-wise addition with a scalar (`lhs + rhs`).
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar, added to every element.
    ///
    /// # Returns
    ///
    /// The tensor holding the result of the addition.
    fn int_add_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
454
455    /// Element-wise power with a IntTensor.
456    ///
457    /// # Arguments
458    ///
459    /// * `lhs` - The left-hand side IntTensor.
460    /// * `rhs` - The right-hand side IntTensor.
461    ///
462    /// # Returns
463    ///
464    /// The elements of `lhs` raised to the power of the elements of `rhs`.
465    fn int_powi(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
466        let dtype = lhs.dtype();
467        let float_dtype = get_device_settings::<B>(&B::int_device(&lhs)).float_dtype;
468        B::float_into_int(
469            B::float_powi(B::int_into_float(lhs, float_dtype), rhs),
470            dtype.into(),
471        )
472    }
473
474    /// Element-wise power with a scalar.
475    ///
476    /// # Backend Implementors Note
477    ///
478    /// A number of common exponent cases can be implemented with operations
479    /// which are much cheaper than generic exponentiation.
480    ///
481    /// This (`Backend` impl overridable) operation handles generic optimizations
482    /// for several common integer exponent cases; and then dispatches to
483    /// the (`Backend` impl overridable) [`Self::int_powi_scalar_impl`]
484    /// operation to handle the generic case.
485    ///
486    /// # Arguments
487    ///
488    /// * `lhs` - The left-hand side tensor.
489    /// * `rhs` - The right-hand side scalar.
490    ///
491    /// # Returns
492    ///
493    /// The elements of `lhs` raised to the value of `rhs`.
494    fn int_powi_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B> {
495        let exp = rhs.elem::<i32>();
496        match exp {
497            0 => Self::int_ones(lhs.shape(), &B::int_device(&lhs), lhs.dtype().into()),
498            1 => lhs,
499            2 => Self::int_mul(lhs.clone(), lhs),
500            _ => Self::int_powi_scalar_impl(lhs, rhs),
501        }
502    }
503
504    /// Element-wise power with a scalar.
505    ///
506    /// # Backend Implementors Note
507    ///
508    /// This is the generic implementation of integer exponentiation
509    /// called by [`Self::int_powi_scalar`] in the fallback case.
510    ///
511    /// By default, this performs a relatively expensive conversion to float,
512    /// exponentiation in float, and conversion back to int.
513    /// This reduces the minimal operation set for `Backend`s,
514    /// at the cost of performance.
515    ///
516    /// This is a good target for specialized optimizations in `Backend` implementations.
517    ///
518    /// As a general rule, this should not be called directly.
519    ///
520    /// # Arguments
521    ///
522    /// * `lhs` - The left-hand side tensor.
523    /// * `rhs` - The right-hand side scalar.
524    ///
525    /// # Returns
526    ///
527    /// The elements of `lhs` raised to the value of `rhs`.
528    fn int_powi_scalar_impl(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B> {
529        let dtype = lhs.dtype();
530        let float_dtype = get_device_settings::<B>(&B::int_device(&lhs)).float_dtype;
531        B::float_into_int(
532            B::float_powi_scalar_impl(B::int_into_float(lhs, float_dtype), rhs),
533            dtype.into(),
534        )
535    }
536
537    /// Clamps a tensor under a minimum value.
538    ///
539    /// # Arguments
540    ///
541    /// * `tensor` - The tensor to clamp.
542    /// * `min` - The minimum value.
543    ///
544    /// # Returns
545    ///
546    /// The clamped tensor.
547    fn int_clamp_min(tensor: IntTensor<B>, min: Scalar) -> IntTensor<B> {
548        let dtype = get_device_settings::<B>(&B::int_device(&tensor)).bool_dtype;
549        let mask = Self::int_lower_elem(tensor.clone(), min, dtype);
550        Self::int_mask_fill(tensor, mask, min)
551    }
552
553    /// Clamps a tensor over a maximum value.
554    ///
555    /// # Arguments
556    ///
557    /// * `tensor` - The tensor to clamp.
558    /// * `max` - The maximum value.
559    ///
560    /// # Returns
561    ///
562    /// The clamped tensor.
563    fn int_clamp_max(tensor: IntTensor<B>, max: Scalar) -> IntTensor<B> {
564        let dtype = get_device_settings::<B>(&B::int_device(&tensor)).bool_dtype;
565        let mask = Self::int_greater_elem(tensor.clone(), max, dtype);
566        Self::int_mask_fill(tensor, mask, max)
567    }
568
569    /// Clamps a tensor between a minimum and maximum value.
570    ///
571    /// # Arguments
572    ///
573    /// * `tensor` - The tensor to clamp.
574    /// * `min` - The minimum value.
575    /// * `max` - The maximum value.
576    ///
577    /// # Returns
578    ///
579    /// The clamped tensor.
580    fn int_clamp(tensor: IntTensor<B>, min: Scalar, max: Scalar) -> IntTensor<B> {
581        Self::int_clamp_min(Self::int_clamp_max(tensor, max), min)
582    }
583
    /// Element-wise subtraction (`lhs - rhs`).
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// The tensor holding the result of the subtraction.
    fn int_sub(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
595
    /// Element-wise subtraction with a scalar (`lhs - rhs`).
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar, subtracted from every element.
    ///
    /// # Returns
    ///
    /// The tensor holding the result of the subtraction.
    fn int_sub_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
607
    /// Element-wise multiplication (`lhs * rhs`).
    ///
    /// The default implementation of [`Self::int_powi_scalar`] uses this method to square a
    /// tensor when the exponent is `2`.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// The tensor holding the result of the multiplication.
    fn int_mul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
619
    /// Element-wise multiplication with a scalar (`lhs * rhs`).
    ///
    /// The default implementation of [`Self::int_neg`] calls this method with `-1`.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar, multiplied with every element.
    ///
    /// # Returns
    ///
    /// The tensor holding the result of the multiplication.
    fn int_mul_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
631
    /// Element-wise division (`lhs / rhs`).
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor (the dividends).
    /// * `rhs` - The right-hand side tensor (the divisors).
    ///
    /// # Returns
    ///
    /// The tensor holding the result of the division.
    fn int_div(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
643
    /// Element-wise division with a scalar (`lhs / rhs`).
    ///
    /// The default implementation of [`Self::int_mean`] uses this method to divide the sum
    /// by the number of elements.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor (the dividends).
    /// * `rhs` - The right-hand side scalar, dividing every element.
    ///
    /// # Returns
    ///
    /// The tensor holding the result of the division.
    fn int_div_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
655
    /// Element-wise modulus (`lhs % rhs`).
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// The result of applying the modulus of the `rhs` elements to the `lhs` elements.
    fn int_remainder(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
666
    /// Element-wise modulus with a scalar (`lhs % rhs`).
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar, applied to every element.
    ///
    /// # Returns
    ///
    /// The result of applying the modulus of the scalar to the tensor.
    fn int_remainder_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
677
    /// Multiplies two tensors together using matrix multiplication.
    ///
    /// Unlike [`Self::int_mul`], this is a matrix product, not an element-wise one.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// The result of multiplying the two tensors together using matrix multiplication.
    fn int_matmul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
689
690    /// Element-wise negation.
691    ///
692    /// # Arguments
693    ///
694    /// * `tensor` - The tensor to negate.
695    ///
696    /// # Returns
697    ///
698    /// The negated tensor.
699    fn int_neg(tensor: IntTensor<B>) -> IntTensor<B> {
700        Self::int_mul_scalar(tensor, (-1).into())
701    }
702
703    /// Creates a tensor of zeros.
704    ///
705    /// # Arguments
706    ///
707    /// * `shape` - The shape of the tensor.
708    /// * `device` - The device to create the tensor on.
709    /// * `dtype` - The target data type.
710    ///
711    /// # Returns
712    ///
713    /// The tensor of zeros.
714    fn int_zeros(shape: Shape, device: &Device<B>, dtype: IntDType) -> IntTensor<B> {
715        Self::int_from_data(TensorData::full_dtype(shape, 0, dtype.into()), device)
716    }
717
718    /// Creates a tensor of ones.
719    ///
720    /// # Arguments
721    ///
722    /// * `shape` - The shape of the tensor.
723    /// * `device` - The device to create the tensor on.
724    /// * `dtype` - The target data type.
725    ///
726    /// # Returns
727    ///
728    /// The tensor of ones.
729    fn int_ones(shape: Shape, device: &Device<B>, dtype: IntDType) -> IntTensor<B> {
730        Self::int_from_data(TensorData::full_dtype(shape, 1, dtype.into()), device)
731    }
732
733    /// Creates a tensor filled with given value.
734    ///
735    /// # Arguments
736    ///
737    /// * `shape` - The shape of the tensor.
738    /// * `fill_value` - The value with which to fill the tensor.
739    /// * `device` - The device to create the tensor on.
740    /// * `dtype` - The target data type.
741    ///
742    /// # Returns
743    ///
744    /// The tensor filled with given value
745    fn int_full(
746        shape: Shape,
747        fill_value: Scalar,
748        device: &Device<B>,
749        dtype: IntDType,
750    ) -> IntTensor<B> {
751        Self::int_from_data(
752            TensorData::full_dtype(shape, fill_value, dtype.into()),
753            device,
754        )
755    }
756
    /// Sums all elements in the tensor.
    ///
    /// The default implementation of [`Self::int_mean`] is built on this method.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to sum.
    ///
    /// # Returns
    ///
    /// A tensor holding the sum of all elements in the tensor.
    fn int_sum(tensor: IntTensor<B>) -> IntTensor<B>;
767
    /// Sums the elements of the tensor along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to sum.
    /// * `dim` - The dimension to sum along.
    ///
    /// # Returns
    ///
    /// A tensor holding the sums of the elements along the dimension.
    fn int_sum_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
779
    /// Computes the product of all elements in the tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the product of.
    ///
    /// # Returns
    ///
    /// A tensor holding the product of all elements in the tensor.
    fn int_prod(tensor: IntTensor<B>) -> IntTensor<B>;
790
    /// Computes the product of the elements of the tensor along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the product of.
    /// * `dim` - The dimension to compute the product along.
    ///
    /// # Returns
    ///
    /// A tensor holding the products of the elements along the dimension.
    fn int_prod_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
802
803    /// Computes the mean of all elements in the tensor.
804    ///
805    /// # Arguments
806    ///
807    /// * `tensor` - The tensor to compute the mean of.
808    ///
809    /// # Returns
810    ///
811    /// The mean of all elements in the tensor.
812    fn int_mean(tensor: IntTensor<B>) -> IntTensor<B> {
813        let num_elems = tensor.shape().num_elements() as i64;
814        B::int_div_scalar(B::int_sum(tensor), num_elems.into())
815    }
816
    /// Computes the mean of the elements of the tensor along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the mean of.
    /// * `dim` - The dimension to compute the mean along.
    ///
    /// # Returns
    ///
    /// A tensor holding the means of the elements along the dimension.
    fn int_mean_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
827
    /// Computes the cumulative sum of elements along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the cumulative sum of.
    /// * `dim` - The dimension along which to compute the cumulative sum.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as the input, where each element is the sum of all
    /// elements up to and including that position along the dimension.
    fn int_cumsum(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
840
    /// Computes the cumulative product of elements along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the cumulative product of.
    /// * `dim` - The dimension along which to compute the cumulative product.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as the input, where each element is the product of all
    /// elements up to and including that position along the dimension.
    fn int_cumprod(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
853
    /// Computes the cumulative minimum of elements along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the cumulative minimum of.
    /// * `dim` - The dimension along which to compute the cumulative minimum.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as the input, where each element is the minimum of all
    /// elements up to and including that position along the dimension.
    fn int_cummin(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
866
    /// Computes the cumulative maximum of elements along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the cumulative maximum of.
    /// * `dim` - The dimension along which to compute the cumulative maximum.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as the input, where each element is the maximum of all
    /// elements up to and including that position along the dimension.
    fn int_cummax(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
879
    /// Gets the indices of the maximum elements along a dimension.
    ///
    /// The default implementations of [`Self::int_max_dim`] and
    /// [`Self::int_max_dim_with_indices`] feed these indices to [`Self::int_gather`].
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the maximum indices of.
    /// * `dim` - The dimension to get the maximum indices along.
    ///
    /// # Returns
    ///
    /// An int tensor holding the indices of the maximum elements along the dimension.
    fn int_argmax(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
891
    /// Gets the indices of the minimum elements along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the minimum indices of.
    /// * `dim` - The dimension to get the minimum indices along.
    ///
    /// # Returns
    ///
    /// An int tensor holding the indices of the minimum elements along the dimension.
    fn int_argmin(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
903
904    /// Gets the maximum element in the tensor.
905    ///
906    /// # Arguments
907    ///
908    /// * `tensor` - The tensor to get the maximum element of.
909    ///
910    /// # Returns
911    ///
912    /// The maximum element in the tensor.
913    fn int_max(tensor: IntTensor<B>) -> IntTensor<B> {
914        let shape = tensor.shape();
915        let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()]));
916
917        B::int_max_dim(tensor, 0)
918    }
919
920    /// Gets the maximum element in the tensor along a dimension.
921    ///
922    /// # Arguments
923    ///
924    /// * `tensor` - The tensor to get the maximum element of.
925    /// * `dim` - The dimension to get the maximum element along.
926    ///
927    /// # Returns
928    ///
929    /// The maximum element in the tensor along the dimension.
930    fn int_max_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
931        let index = B::int_argmax(tensor.clone(), dim);
932        B::int_gather(dim, tensor, index)
933    }
934
935    /// Gets the maximum elements and corresponding indices along a dimension.
936    ///
937    /// # Arguments
938    ///
939    /// * `tensor` - The tensor to get the maximum elements and indices of.
940    /// * `dim` - The dimension to get the maximum elements and indices along.
941    ///
942    /// # Returns
943    ///
944    /// The maximum elements and corresponding indices along the dimension.
945    fn int_max_dim_with_indices(tensor: IntTensor<B>, dim: usize) -> (IntTensor<B>, IntTensor<B>) {
946        let index = B::int_argmax(tensor.clone(), dim);
947        let values = B::int_gather(dim, tensor, index.clone());
948
949        (values, index)
950    }
951
952    /// Gets the maximum absolute element in the tensor.
953    ///
954    /// # Arguments
955    ///
956    /// * `tensor` - The tensor to get the maximum element of.
957    ///
958    /// # Returns
959    ///
960    /// The maximum element in the tensor.
961    fn int_max_abs(tensor: IntTensor<B>) -> IntTensor<B> {
962        let shape = tensor.shape();
963        let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()]));
964
965        B::int_max_abs_dim(tensor, 0)
966    }
967
968    /// Gets the maximum absolute element in the tensor along a dimension.
969    ///
970    /// # Arguments
971    ///
972    /// * `tensor` - The tensor to get the maximum element of.
973    /// * `dim` - The dimension to get the maximum element along.
974    ///
975    /// # Returns
976    ///
977    /// The maximum element in the tensor along the dimension.
978    fn int_max_abs_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
979        B::int_max_dim(B::int_abs(tensor), dim)
980    }
981
982    /// Gets the minimum element in the tensor.
983    ///
984    /// # Arguments
985    ///
986    /// * `tensor` - The tensor to get the minimum element of.
987    ///
988    /// # Returns
989    ///
990    /// The minimum element in the tensor.
991    fn int_min(tensor: IntTensor<B>) -> IntTensor<B> {
992        let shape = tensor.shape();
993        let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()]));
994
995        B::int_min_dim(tensor, 0)
996    }
997
998    /// Gets the minimum elements in the tensor along a dimension.
999    ///
1000    /// # Arguments
1001    ///
1002    /// * `tensor` - The tensor to get the minimum element of.
1003    /// * `dim` - The dimension to get the minimum element along.
1004    ///
1005    /// # Returns
1006    ///
1007    /// The minimum element in the tensor along the dimension.
1008    fn int_min_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
1009        let index = B::int_argmin(tensor.clone(), dim);
1010        B::int_gather(dim, tensor, index)
1011    }
1012
1013    /// Gets the minimum elements and corresponding indices along a dimension.
1014    ///
1015    /// # Arguments
1016    ///
1017    /// * `tensor` - The tensor to get the minimum elements and indices of.
1018    /// * `dim` - The dimension to get the minimum elements and indices along.
1019    ///
1020    /// # Returns
1021    ///
1022    /// The minimum elements and corresponding indices along the dimension.
1023    fn int_min_dim_with_indices(tensor: IntTensor<B>, dim: usize) -> (IntTensor<B>, IntTensor<B>) {
1024        let indices = B::int_argmin(tensor.clone(), dim);
1025        let values = B::int_gather(dim, tensor, indices.clone());
1026
1027        (values, indices)
1028    }
1029
1030    /// Returns a new tensor with absolute values.
1031    ///
1032    /// # Arguments
1033    ///
1034    /// * `tensor` - The tensor to take absolute value of.
1035    ///
1036    /// # Returns
1037    ///
1038    /// A tensor with the same shape as `tensor` with absolute values.
1039    fn int_abs(tensor: IntTensor<B>) -> IntTensor<B>;
1040
1041    /// Transposes an int tensor.
1042    ///
1043    /// # Arguments
1044    ///
1045    /// * `tensor` - The tensor to transpose.
1046    ///
1047    /// # Returns
1048    ///
1049    /// The transposed tensor.
1050    fn int_transpose(tensor: IntTensor<B>) -> IntTensor<B> {
1051        let ndims = tensor.shape().num_dims();
1052        Self::int_swap_dims(tensor, ndims - 2, ndims - 1)
1053    }
1054
1055    /// Swaps two dimensions of an int tensor.
1056    ///
1057    /// # Arguments
1058    ///
1059    /// * `tensor` - The tensor to swap the dimensions of.
1060    /// * `dim1` - The first dimension to swap.
1061    /// * `dim2` - The second dimension to swap.
1062    ///
1063    /// # Returns
1064    ///
1065    /// The tensor with the dimensions swapped.
1066    fn int_swap_dims(tensor: IntTensor<B>, dim1: usize, dim2: usize) -> IntTensor<B>;
1067
1068    /// Permutes the dimensions of a tensor.
1069    ///
1070    /// # Arguments
1071    ///
1072    /// * `tensor` - The tensor to permute the dimensions of.
1073    /// * `axes` - The new order of the dimensions.
1074    /// # Returns
1075    ///
1076    /// The tensor with the dimensions permuted.
1077    fn int_permute(tensor: IntTensor<B>, axes: &[usize]) -> IntTensor<B>;
1078
1079    /// Reverse the order of elements in a tensor along the given axes.
1080    ///
1081    /// # Arguments
1082    ///
1083    /// * `tensor` - The tensor to reverse.
1084    /// * `axes` - The axes to reverse.
1085    ///
1086    /// The tensor with the elements reversed.
1087    fn int_flip(tensor: IntTensor<B>, axes: &[usize]) -> IntTensor<B>;
1088
1089    /// Creates a new int tensor with random values.
1090    ///
1091    ///  # Arguments
1092    ///  * `shape` - The shape of the tensor.
1093    ///  * `distribution` - The distribution to sample from.
1094    ///  * `device` - The device to create the tensor on.
1095    /// * `dtype` - The target data type.
1096    ///
1097    ///  # Returns
1098    ///
1099    ///  The tensor with the given shape and random values.
1100    fn int_random(
1101        shape: Shape,
1102        distribution: Distribution,
1103        device: &Device<B>,
1104        dtype: IntDType,
1105    ) -> IntTensor<B>;
1106
1107    /// Creates a new tensor with values from the given range with the given step size.
1108    ///
1109    /// # Arguments
1110    ///
1111    /// * `range` - The range of values.
1112    /// * `step` - The step size.
1113    /// * `device` - The device to create the tensor on.
1114    /// * `dtype` - The target data type.
1115    ///
1116    /// # Returns
1117    ///
1118    /// The tensor with the given values.
1119    fn int_arange_step(
1120        range: Range<i64>,
1121        step: usize,
1122        device: &Device<B>,
1123        dtype: IntDType,
1124    ) -> IntTensor<B> {
1125        let value = range
1126            .step_by(step)
1127            .map(|i| i.elem())
1128            .collect::<Vec<IntElem<B>>>();
1129        let shape = Shape::new([value.len()]);
1130        let data = TensorData::new(value, shape).convert_dtype(dtype.into());
1131        B::int_from_data(data, device)
1132    }
1133
1134    /// Creates a new tensor with values from the given range.
1135    ///
1136    /// # Arguments
1137    ///
1138    /// * `range` - The range of values.
1139    /// * `device` - The device to create the tensor on.
1140    ///
1141    /// # Returns
1142    ///
1143    /// The tensor with the given values.
1144    ///
1145    /// # Remarks
1146    ///
1147    /// Uses `arange_step` with a step size of 1 under the hood.
1148    fn int_arange(range: Range<i64>, device: &Device<B>, dtype: IntDType) -> IntTensor<B> {
1149        Self::int_arange_step(range, 1, device, dtype)
1150    }
1151
1152    /// Tests if any element in the int `tensor` evaluates to True.
1153    ///
1154    /// # Arguments
1155    ///
1156    /// * `tensor` - The tensor to test.
1157    ///
1158    /// # Returns
1159    ///
1160    /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
1161    fn int_any(tensor: IntTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
1162        let int_dtype = tensor.dtype();
1163        let bool_tensor = B::int_equal_elem(tensor, 0.into(), out_dtype);
1164        let bool_tensor = B::bool_not(bool_tensor);
1165        let sum = B::int_sum(B::bool_into_int(bool_tensor, int_dtype.into()));
1166        B::int_greater_elem(sum, 0.into(), out_dtype)
1167    }
1168
1169    /// Tests if any element in the int `tensor` evaluates to True along a given dimension `dim`.
1170    ///
1171    /// # Arguments
1172    ///
1173    /// * `tensor` - The tensor to test.
1174    /// * `dim` - The axis along which to test.
1175    ///
1176    /// # Returns
1177    ///
1178    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1179    /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input
1180    /// evaluates to True, False otherwise.
1181    fn int_any_dim(tensor: IntTensor<B>, dim: usize, out_dtype: BoolDType) -> BoolTensor<B> {
1182        let int_dtype = tensor.dtype();
1183        let bool_tensor = B::int_equal_elem(tensor, 0.into(), out_dtype);
1184        let bool_tensor = B::bool_not(bool_tensor);
1185        let sum = B::int_sum_dim(B::bool_into_int(bool_tensor, int_dtype.into()), dim);
1186        B::int_greater_elem(sum, 0.into(), out_dtype)
1187    }
1188
1189    /// Tests if all elements in the int `tensor` evaluate to True.
1190    ///
1191    /// # Arguments
1192    ///
1193    /// * `tensor` - The tensor to test.
1194    /// * `out_dtype` - The output tensor dtype.
1195    ///
1196    /// # Returns
1197    ///
1198    /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1199    /// evaluate to True, False otherwise.
1200    fn int_all(tensor: IntTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
1201        let int_dtype = tensor.dtype();
1202        let num_elems = tensor.shape().num_elements() as i64;
1203        let bool_tensor = B::int_equal_elem(tensor, 0.into(), out_dtype);
1204        let bool_tensor = B::bool_not(bool_tensor);
1205        let sum = B::int_sum(B::bool_into_int(bool_tensor, int_dtype.into()));
1206        B::int_equal_elem(sum, num_elems.into(), out_dtype)
1207    }
1208
1209    /// Tests if all elements in the int `tensor` evaluate to True along a given dimension `dim`.
1210    ///
1211    /// # Arguments
1212    ///
1213    /// * `tensor` - The tensor to test.
1214    /// * `dim` - The axis along which to test.
1215    /// * `out_dtype` - The output tensor dtype.
1216    ///
1217    /// # Returns
1218    ///
1219    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1220    /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1221    /// evaluates to True, False otherwise.
1222    fn int_all_dim(tensor: IntTensor<B>, dim: usize, out_dtype: BoolDType) -> BoolTensor<B> {
1223        let int_dtype = tensor.dtype();
1224        let num_elems = tensor.shape()[dim] as i64;
1225        let bool_tensor = B::int_equal_elem(tensor, 0.into(), out_dtype);
1226        let bool_tensor = B::bool_not(bool_tensor);
1227        let sum = B::int_sum_dim(B::bool_into_int(bool_tensor, int_dtype.into()), dim);
1228        B::int_equal_elem(sum, num_elems.into(), out_dtype)
1229    }
1230
1231    /// Returns the signs of the int `tensor`.
1232    ///
1233    /// # Arguments
1234    ///
1235    /// * `tensor` - The tensor to extract the signs from.
1236    ///
1237    /// # Returns
1238    ///
1239    /// A tensor with the same shape as `tensor` containing the signs of the elements of `tensor`.
1240    fn int_sign(tensor: IntTensor<B>) -> IntTensor<B> {
1241        let dtype = tensor.dtype();
1242        let device = B::int_device(&tensor);
1243        let bool_dtype = get_device_settings::<B>(&B::int_device(&tensor)).bool_dtype;
1244        let zeros = B::int_zeros(tensor.shape(), &device, dtype.into());
1245        let less_than_zero = B::int_lower_elem(tensor.clone(), 0.into(), bool_dtype);
1246        let greater_than_zero = B::int_greater_elem(tensor, 0.into(), bool_dtype);
1247
1248        let mut result = B::int_mask_fill(zeros, less_than_zero, (-1).into());
1249        result = B::int_mask_fill(result, greater_than_zero, 1.into());
1250        result
1251    }
1252
1253    /// Broadcasts the int `tensor` to the given `shape`.
1254    fn int_expand(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B>;
1255
1256    /// Sort the elements of the input `tensor` by value along a given dimension.
1257    ///
1258    /// This sort is unstable (i.e., may reorder equal elements).
1259    ///
1260    /// # Arguments
1261    ///
1262    /// * `tensor` - The input tensor.
1263    /// * `dim` - The axis along which to sort.
1264    /// * `descending` - The sorting order.
1265    ///
1266    /// # Returns
1267    ///
1268    /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
1269    fn int_sort(tensor: IntTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
1270        sort::<B, Int>(tensor, dim, descending)
1271    }
1272
1273    /// Sort the elements of the input `tensor` by value along a given dimension.
1274    ///
1275    /// This sort is unstable (i.e., may reorder equal elements).
1276    ///
1277    /// # Arguments
1278    ///
1279    /// * `tensor` - The input tensor.
1280    /// * `dim` - The axis along which to sort.
1281    ///
1282    /// # Returns
1283    ///
1284    /// A tensor with the same shape as the input tensor and corresponding indices, where
1285    /// the elements are sorted by value and the indices map back to the original input tensor.
1286    fn int_sort_with_indices(
1287        tensor: IntTensor<B>,
1288        dim: usize,
1289        descending: bool,
1290    ) -> (IntTensor<B>, IntTensor<B>) {
1291        let dtype = tensor.dtype();
1292        sort_with_indices::<B, Int>(tensor, dim, descending, dtype.into())
1293    }
1294
1295    /// Returns the indices that sort the elements of the input `tensor` by value
1296    /// along a given dimension.
1297    ///
1298    /// This sort is unstable (i.e., may reorder equal elements).
1299    ///
1300    /// # Arguments
1301    ///
1302    /// * `tensor` - The input tensor.
1303    /// * `dim` - The axis along which to sort.
1304    /// * `descending` - The sorting order.
1305    ///
1306    /// # Returns
1307    ///
1308    /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.
1309    fn int_argsort(tensor: IntTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
1310        let dtype = tensor.dtype();
1311        argsort::<B, Int>(tensor, dim, descending, dtype.into())
1312    }
1313
1314    /// Bitwise AND operation for Int Tensors
1315    fn bitwise_and(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
1316
1317    /// Bitwise AND operation for Int Tensors with a scalar
1318    fn bitwise_and_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
1319
1320    /// Bitwise OR operation for Int Tensors
1321    fn bitwise_or(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
1322
1323    /// Bitwise OR operation for Int Tensors with a scalar
1324    fn bitwise_or_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
1325
1326    /// Bitwise XOR operation for Int Tensors
1327    fn bitwise_xor(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
1328
1329    /// Bitwise XOR operation for Int Tensors with a scalar
1330    fn bitwise_xor_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
1331
1332    /// Bitwise NOT operation for Int Tensors
1333    fn bitwise_not(tensor: IntTensor<B>) -> IntTensor<B>;
1334
1335    /// Bitwise left shift operation for Int Tensors
1336    fn bitwise_left_shift(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
1337
1338    /// Bitwise left shift operation for Int Tensors with a scalar
1339    fn bitwise_left_shift_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
1340
1341    /// Bitwise right shift operation for Int Tensors
1342    fn bitwise_right_shift(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
1343
1344    /// Bitwise right shift operation for Int Tensors with a scalar
1345    fn bitwise_right_shift_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
1346
1347    /// Converts a tensor to another integer data type.
1348    ///
1349    /// # Arguments
1350    ///
1351    /// * `tensor` - The tensor to convert.
1352    /// * `dtype` - The target data type.
1353    ///
1354    /// # Returns
1355    ///
1356    /// A tensor with the same values as `tensor` but in the target integer data type.
1357    fn int_cast(tensor: IntTensor<B>, dtype: IntDType) -> IntTensor<B>;
1358
1359    /// Unfold windows along a dimension.
1360    ///
1361    /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;
1362    /// where windows are advanced by `step` at each index.
1363    ///
1364    /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.
1365    ///
1366    /// # Arguments
1367    ///
1368    /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
1369    /// * `dim` - the selected dim.
1370    /// * `size` - the size of each unfolded window.
1371    /// * `step` - the step between each window.
1372    ///
1373    /// # Returns
1374    ///
1375    /// A tensor view with shape ``[pre=..., windows, size, post=...]``.
1376    fn int_unfold(tensor: IntTensor<B>, dim: usize, size: usize, step: usize) -> IntTensor<B>;
1377}