burn_backend/backend/ops/int_tensor.rs
1use super::cat::cat_with_slice_assign;
2use super::repeat_dim::repeat_with_slice_assign;
3use super::sort::{argsort, sort, sort_with_indices};
4use crate::tensor::{BoolTensor, Device, FloatTensor, Int, IntElem, IntTensor};
5use crate::{Backend, Distribution, TensorData, TensorMetadata, element::ElementConversion};
6use crate::{ExecutionError, Scalar, get_device_settings};
7use alloc::vec::Vec;
8use burn_std::{BoolDType, FloatDType, IntDType, Shape, Slice};
9use core::ops::Range;
10
11/// Int Tensor API for basic and numeric operations, see
12#[cfg_attr(doc, doc = crate::doc_tensor!())]
13#[cfg_attr(not(doc), doc = "`Tensor`")]
14/// for documentation on each function.
15pub trait IntTensorOps<B: Backend> {
16 /// Creates a new int tensor.
17 ///
18 /// # Arguments
19 ///
20 /// * `shape` - The shape of the tensor.
21 /// * `device` - The device to create the tensor on.
22 /// * `dtype` - The target data type.
23 ///
24 /// # Returns
25 ///
26 /// The integer tensor with the given shape.
27 fn int_empty(shape: Shape, device: &Device<B>, dtype: IntDType) -> IntTensor<B>;
28
29 /// Converts the tensor to a data structure.
30 ///
31 /// # Arguments
32 ///
33 /// * `tensor` - The tensor.
34 ///
35 /// # Returns
36 ///
37 /// The data structure with the tensor's data.
38 fn int_into_data(
39 tensor: IntTensor<B>,
40 ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;
41
42 /// Creates a tensor from the data structure.
43 ///
44 /// # Arguments
45 ///
46 /// * `data` - The data structure.
47 /// * `device` - The device to create the tensor on.
48 ///
49 /// # Returns
50 ///
51 /// The tensor with the data.
52 fn int_from_data(data: TensorData, device: &Device<B>) -> IntTensor<B>;
53
54 /// Gets the device of the tensor.
55 ///
56 /// # Arguments
57 ///
58 /// * `tensor` - The tensor.
59 ///
60 /// # Returns
61 ///
62 /// The device of the tensor.
63 fn int_device(tensor: &IntTensor<B>) -> Device<B>;
64
65 /// Moves the tensor to the given device.
66 fn int_to_device(tensor: IntTensor<B>, device: &Device<B>) -> IntTensor<B>;
67
68 /// Reshapes the tensor.
69 ///
70 /// # Arguments
71 ///
72 /// * `tensor` - The tensor.
73 /// * `shape` - The new shape.
74 ///
75 /// # Returns
76 ///
77 /// The tensor with the new shape.
78 fn int_reshape(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B>;
79
80 /// Gets the element at the given indices.
81 ///
82 /// # Arguments
83 ///
84 /// * `tensor` - The tensor.
85 /// * `slices` - The slices specifying ranges and steps for each dimension.
86 ///
87 /// # Returns
88 ///
89 /// The elements at the given indices.
90 ///
91 /// # Note
92 ///
93 /// Empty slices (where start >= end) are handled at the high-level tensor API and will not
94 /// be passed to this method. Backend implementations do not need to handle empty slices.
95 fn int_slice(tensor: IntTensor<B>, slices: &[Slice]) -> IntTensor<B>;
96
97 /// Sets the values in the tensor for the given ranges.
98 ///
99 /// # Arguments
100 ///
101 /// * `tensor` - The tensor.
102 /// * `ranges` - The ranges to set the values for.
103 ///
104 /// # Returns
105 ///
106 /// The tensor with the values set for the given ranges.
107 ///
108 /// # Note
109 ///
110 /// Empty slice assignments (where any slice range produces 0 elements) are handled at the
111 /// high-level tensor API and will not be passed to this method. Backend implementations do
112 /// not need to handle empty slice assignments.
113 fn int_slice_assign(
114 tensor: IntTensor<B>,
115 slices: &[Slice],
116 value: IntTensor<B>,
117 ) -> IntTensor<B>;
118
119 /// Converts int tensor to float tensor.
120 ///
121 /// # Arguments
122 ///
123 /// * `tensor` - The tensor.
124 /// * `out_dtype` - The output tensor dtype.
125 ///
126 /// # Returns
127 ///
128 /// The int tensor with the same data as the float tensor.
129 fn int_into_float(tensor: IntTensor<B>, out_dtype: FloatDType) -> FloatTensor<B>;
130
131 /// Fills the tensor with values from the value tensor if the mask is true at the given
132 /// indices.
133 ///
134 /// # Arguments
135 ///
136 /// * `tensor` - The tensor.
137 /// * `mask` - The mask.
138 /// * `value` - The value tensor.
139 ///
140 /// # Returns
141 ///
142 /// The tensor with the values filled.
143 fn int_mask_where(
144 tensor: IntTensor<B>,
145 mask: BoolTensor<B>,
146 value: IntTensor<B>,
147 ) -> IntTensor<B>;
148
149 /// Fills the tensor with the given value if the mask is true at the given indices.
150 ///
151 /// # Arguments
152 ///
153 /// * `tensor` - The tensor.
154 /// * `mask` - The mask.
155 /// * `value` - The value.
156 ///
157 /// # Returns
158 ///
159 /// The tensor with the values filled.
160 fn int_mask_fill(tensor: IntTensor<B>, mask: BoolTensor<B>, value: Scalar) -> IntTensor<B>;
161
162 /// Gather elements from the tensor at the given indices.
163 ///
164 /// # Arguments
165 ///
166 /// * `dim` - The dimension to gather from.
167 /// * `tensor` - The tensor.
168 /// * `indices` - The indices.
169 fn int_gather(dim: usize, tensor: IntTensor<B>, indices: IntTensor<B>) -> IntTensor<B>;
170
171 /// Scatter a given value to the tensor at the given indices using sum reduction.
172 ///
173 /// # Arguments
174 ///
175 /// * `dim` - The dimension to scatter to.
176 /// * `tensor` - The tensor.
177 /// * `indices` - The indices.
178 /// * `value` - The value.
179 ///
180 /// # Returns
181 ///
182 /// The tensor with the values scattered.
183 fn int_scatter_add(
184 dim: usize,
185 tensor: IntTensor<B>,
186 indices: IntTensor<B>,
187 value: IntTensor<B>,
188 ) -> IntTensor<B>;
189
190 /// Select tensor elements along the given dimension corresponding to the given indices.
191 ///
192 /// # Arguments
193 ///
194 /// * `tensor` - The tensor.
195 /// * `dim` - The dimension to select from.
196 /// * `indices` - The indices.
197 ///
198 /// # Returns
199 ///
200 /// The tensor with the selected elements.
201 fn int_select(tensor: IntTensor<B>, dim: usize, indices: IntTensor<B>) -> IntTensor<B>;
202
203 /// Assign the selected elements along the given dimension corresponding to the given indices
204 /// to the given value using sum reduction.
205 ///
206 /// # Arguments
207 ///
208 /// * `tensor` - The tensor.
209 /// * `dim` - The dimension to select from.
210 /// * `indices` - The indices.
211 /// * `value` - The value.
212 ///
213 /// # Returns
214 ///
215 /// The tensor with the selected elements assigned to the given value.
216 fn int_select_add(
217 tensor: IntTensor<B>,
218 dim: usize,
219 indices: IntTensor<B>,
220 value: IntTensor<B>,
221 ) -> IntTensor<B>;
222
223 /// Repeats the tensor along the given dimension the given number of times.
224 ///
225 /// # Arguments
226 ///
227 /// * `tensor` - The tensor.
228 /// * `dim` - The dimension to repeat.
229 /// * `times` - The number of times to repeat.
230 ///
231 /// # Returns
232 ///
233 /// The tensor with the given dimension repeated the given number of times.
234 fn int_repeat_dim(tensor: IntTensor<B>, dim: usize, times: usize) -> IntTensor<B> {
235 repeat_with_slice_assign::<B, Int>(tensor, dim, times)
236 }
237
238 /// Concatenates the given tensors along the given dimension.
239 ///
240 /// # Arguments
241 ///
242 /// * `tensors` - The tensors.
243 /// * `dim` - The dimension to concatenate along.
244 ///
245 /// # Returns
246 ///
247 /// The concatenated tensor.
248 ///
249 /// # Note
250 ///
251 /// Empty tensors (where the concatenation dimension has size 0) are filtered out at the
252 /// high-level tensor API and will not be passed to this method. Backend implementations do
253 /// not need to handle empty tensors.
254 fn int_cat(tensors: Vec<IntTensor<B>>, dim: usize) -> IntTensor<B> {
255 cat_with_slice_assign::<B, Int>(tensors, dim)
256 }
257
258 /// Element-wise equality comparison.
259 ///
260 /// # Arguments
261 ///
262 /// * `lhs` - The left-hand side tensor.
263 /// * `rhs` - The right-hand side tensor.
264 /// * `out_dtype` - The output tensor dtype.
265 ///
266 /// # Returns
267 ///
268 /// The boolean tensor with the result of the comparison.
269 fn int_equal(lhs: IntTensor<B>, rhs: IntTensor<B>, out_dtype: BoolDType) -> BoolTensor<B>;
270
271 /// Element-wise non-equality comparison.
272 ///
273 /// # Arguments
274 ///
275 /// * `lhs` - The left-hand side tensor.
276 /// * `rhs` - The right-hand side tensor.
277 /// * `out_dtype` - The output tensor dtype.
278 ///
279 /// # Returns
280 ///
281 /// The boolean tensor with the result of the comparison.
282 fn int_not_equal(lhs: IntTensor<B>, rhs: IntTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
283 let equal_tensor = B::int_equal(lhs, rhs, out_dtype);
284 B::bool_not(equal_tensor)
285 }
286
287 /// Element-wise equality comparison with a scalar.
288 ///
289 /// # Arguments
290 ///
291 /// * `lhs` - The left-hand side tensor.
292 /// * `rhs` - The right-hand side scalar.
293 /// * `out_dtype` - The output tensor dtype.
294 ///
295 /// # Returns
296 ///
297 /// The boolean tensor with the result of the comparison.
298 fn int_equal_elem(lhs: IntTensor<B>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<B>;
299
300 /// Element-wise non-equality comparison with a scalar.
301 ///
302 /// # Arguments
303 ///
304 /// * `lhs` - The left-hand side tensor.
305 /// * `rhs` - The right-hand side scalar.
306 /// * `out_dtype` - The output tensor dtype.
307 ///
308 /// # Returns
309 ///
310 /// The boolean tensor with the result of the comparison.
311 fn int_not_equal_elem(lhs: IntTensor<B>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<B> {
312 let equal_tensor = B::int_equal_elem(lhs, rhs, out_dtype);
313 B::bool_not(equal_tensor)
314 }
315
316 /// Element-wise greater than comparison.
317 ///
318 /// # Arguments
319 ///
320 /// * `lhs` - The left-hand side tensor.
321 /// * `rhs` - The right-hand side tensor.
322 /// * `out_dtype` - The output tensor dtype.
323 ///
324 /// # Returns
325 ///
326 /// The boolean tensor with the result of the comparison.
327 fn int_greater(lhs: IntTensor<B>, rhs: IntTensor<B>, out_dtype: BoolDType) -> BoolTensor<B>;
328
329 /// Element-wise greater than comparison with a scalar.
330 ///
331 /// # Arguments
332 ///
333 /// * `lhs` - The left-hand side tensor.
334 /// * `rhs` - The right-hand side scalar.
335 /// * `out_dtype` - The output tensor dtype.
336 ///
337 /// # Returns
338 ///
339 /// The boolean tensor with the result of the comparison.
340 fn int_greater_elem(lhs: IntTensor<B>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<B>;
341
342 /// Element-wise greater than or equal comparison.
343 ///
344 /// # Arguments
345 ///
346 /// * `lhs` - The left-hand side tensor.
347 /// * `rhs` - The right-hand side tensor.
348 /// * `out_dtype` - The output tensor dtype.
349 ///
350 /// # Returns
351 ///
352 /// The boolean tensor with the result of the comparison.
353 fn int_greater_equal(
354 lhs: IntTensor<B>,
355 rhs: IntTensor<B>,
356 out_dtype: BoolDType,
357 ) -> BoolTensor<B>;
358
359 /// Element-wise greater than or equal comparison with a scalar.
360 ///
361 /// # Arguments
362 ///
363 /// * `lhs` - The left-hand side tensor.
364 /// * `rhs` - The right-hand side scalar.
365 /// * `out_dtype` - The output tensor dtype.
366 ///
367 /// # Returns
368 ///
369 /// The boolean tensor with the result of the comparison.
370 fn int_greater_equal_elem(
371 lhs: IntTensor<B>,
372 rhs: Scalar,
373 out_dtype: BoolDType,
374 ) -> BoolTensor<B>;
375
376 /// Element-wise less than comparison.
377 ///
378 /// # Arguments
379 ///
380 /// * `lhs` - The left-hand side tensor.
381 /// * `rhs` - The right-hand side tensor.
382 /// * `out_dtype` - The output tensor dtype.
383 ///
384 /// # Returns
385 ///
386 /// The boolean tensor with the result of the comparison.
387 fn int_lower(lhs: IntTensor<B>, rhs: IntTensor<B>, out_dtype: BoolDType) -> BoolTensor<B>;
388
389 /// Element-wise less than comparison with a scalar.
390 ///
391 /// # Arguments
392 ///
393 /// * `lhs` - The left-hand side tensor.
394 /// * `rhs` - The right-hand side scalar.
395 /// * `out_dtype` - The output tensor dtype.
396 ///
397 /// # Returns
398 ///
399 /// The boolean tensor with the result of the comparison.
400 fn int_lower_elem(lhs: IntTensor<B>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<B>;
401
402 /// Element-wise less than or equal comparison.
403 ///
404 /// # Arguments
405 ///
406 /// * `lhs` - The left-hand side tensor.
407 /// * `rhs` - The right-hand side tensor.
408 /// * `out_dtype` - The output tensor dtype.
409 ///
410 /// # Returns
411 ///
412 /// The boolean tensor with the result of the comparison.
413 fn int_lower_equal(lhs: IntTensor<B>, rhs: IntTensor<B>, out_dtype: BoolDType)
414 -> BoolTensor<B>;
415
416 /// Element-wise less than or equal comparison with a scalar.
417 ///
418 /// # Arguments
419 ///
420 /// * `lhs` - The left-hand side tensor.
421 /// * `rhs` - The right-hand side scalar.
422 /// * `out_dtype` - The output tensor dtype.
423 ///
424 /// # Returns
425 ///
426 /// The boolean tensor with the result of the comparison.
427 fn int_lower_equal_elem(lhs: IntTensor<B>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<B>;
428
429 // ==== NUMERIC ==== //
430
431 /// Element-wise addition.
432 ///
433 /// # Arguments
434 ///
435 /// * `lhs` - The left-hand side tensor.
436 /// * `rhs` - The right-hand side tensor.
437 ///
438 /// # Returns
439 ///
440 /// The result of the addition.
441 fn int_add(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
442
443 /// Element-wise addition with a scalar.
444 ///
445 /// # Arguments
446 ///
447 /// * `lhs` - The left-hand side tensor.
448 /// * `rhs` - The right-hand side scalar.
449 ///
450 /// # Returns
451 ///
452 /// The result of the addition.
453 fn int_add_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
454
455 /// Element-wise power with a IntTensor.
456 ///
457 /// # Arguments
458 ///
459 /// * `lhs` - The left-hand side IntTensor.
460 /// * `rhs` - The right-hand side IntTensor.
461 ///
462 /// # Returns
463 ///
464 /// The elements of `lhs` raised to the power of the elements of `rhs`.
465 fn int_powi(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
466 let dtype = lhs.dtype();
467 let float_dtype = get_device_settings::<B>(&B::int_device(&lhs)).float_dtype;
468 B::float_into_int(
469 B::float_powi(B::int_into_float(lhs, float_dtype), rhs),
470 dtype.into(),
471 )
472 }
473
474 /// Element-wise power with a scalar.
475 ///
476 /// # Backend Implementors Note
477 ///
478 /// A number of common exponent cases can be implemented with operations
479 /// which are much cheaper than generic exponentiation.
480 ///
481 /// This (`Backend` impl overridable) operation handles generic optimizations
482 /// for several common integer exponent cases; and then dispatches to
483 /// the (`Backend` impl overridable) [`Self::int_powi_scalar_impl`]
484 /// operation to handle the generic case.
485 ///
486 /// # Arguments
487 ///
488 /// * `lhs` - The left-hand side tensor.
489 /// * `rhs` - The right-hand side scalar.
490 ///
491 /// # Returns
492 ///
493 /// The elements of `lhs` raised to the value of `rhs`.
494 fn int_powi_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B> {
495 let exp = rhs.elem::<i32>();
496 match exp {
497 0 => Self::int_ones(lhs.shape(), &B::int_device(&lhs), lhs.dtype().into()),
498 1 => lhs,
499 2 => Self::int_mul(lhs.clone(), lhs),
500 _ => Self::int_powi_scalar_impl(lhs, rhs),
501 }
502 }
503
504 /// Element-wise power with a scalar.
505 ///
506 /// # Backend Implementors Note
507 ///
508 /// This is the generic implementation of integer exponentiation
509 /// called by [`Self::int_powi_scalar`] in the fallback case.
510 ///
511 /// By default, this performs a relatively expensive conversion to float,
512 /// exponentiation in float, and conversion back to int.
513 /// This reduces the minimal operation set for `Backend`s,
514 /// at the cost of performance.
515 ///
516 /// This is a good target for specialized optimizations in `Backend` implementations.
517 ///
518 /// As a general rule, this should not be called directly.
519 ///
520 /// # Arguments
521 ///
522 /// * `lhs` - The left-hand side tensor.
523 /// * `rhs` - The right-hand side scalar.
524 ///
525 /// # Returns
526 ///
527 /// The elements of `lhs` raised to the value of `rhs`.
528 fn int_powi_scalar_impl(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B> {
529 let dtype = lhs.dtype();
530 let float_dtype = get_device_settings::<B>(&B::int_device(&lhs)).float_dtype;
531 B::float_into_int(
532 B::float_powi_scalar_impl(B::int_into_float(lhs, float_dtype), rhs),
533 dtype.into(),
534 )
535 }
536
537 /// Clamps a tensor under a minimum value.
538 ///
539 /// # Arguments
540 ///
541 /// * `tensor` - The tensor to clamp.
542 /// * `min` - The minimum value.
543 ///
544 /// # Returns
545 ///
546 /// The clamped tensor.
547 fn int_clamp_min(tensor: IntTensor<B>, min: Scalar) -> IntTensor<B> {
548 let dtype = get_device_settings::<B>(&B::int_device(&tensor)).bool_dtype;
549 let mask = Self::int_lower_elem(tensor.clone(), min, dtype);
550 Self::int_mask_fill(tensor, mask, min)
551 }
552
553 /// Clamps a tensor over a maximum value.
554 ///
555 /// # Arguments
556 ///
557 /// * `tensor` - The tensor to clamp.
558 /// * `max` - The maximum value.
559 ///
560 /// # Returns
561 ///
562 /// The clamped tensor.
563 fn int_clamp_max(tensor: IntTensor<B>, max: Scalar) -> IntTensor<B> {
564 let dtype = get_device_settings::<B>(&B::int_device(&tensor)).bool_dtype;
565 let mask = Self::int_greater_elem(tensor.clone(), max, dtype);
566 Self::int_mask_fill(tensor, mask, max)
567 }
568
569 /// Clamps a tensor between a minimum and maximum value.
570 ///
571 /// # Arguments
572 ///
573 /// * `tensor` - The tensor to clamp.
574 /// * `min` - The minimum value.
575 /// * `max` - The maximum value.
576 ///
577 /// # Returns
578 ///
579 /// The clamped tensor.
580 fn int_clamp(tensor: IntTensor<B>, min: Scalar, max: Scalar) -> IntTensor<B> {
581 Self::int_clamp_min(Self::int_clamp_max(tensor, max), min)
582 }
583
584 /// Element-wise subtraction.
585 ///
586 /// # Arguments
587 ///
588 /// * `lhs` - The left-hand side tensor.
589 /// * `rhs` - The right-hand side tensor.
590 ///
591 /// # Returns
592 ///
593 /// The result of the subtraction.
594 fn int_sub(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
595
596 /// Element-wise subtraction with a scalar.
597 ///
598 /// # Arguments
599 ///
600 /// * `lhs` - The left-hand side tensor.
601 /// * `rhs` - The right-hand side scalar.
602 ///
603 /// # Returns
604 ///
605 /// The result of the subtraction.
606 fn int_sub_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
607
608 /// Element-wise multiplication.
609 ///
610 /// # Arguments
611 ///
612 /// * `lhs` - The left-hand side tensor.
613 /// * `rhs` - The right-hand side tensor.
614 ///
615 /// # Returns
616 ///
617 /// The result of the multiplication.
618 fn int_mul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
619
620 /// Element-wise multiplication with a scalar.
621 ///
622 /// # Arguments
623 ///
624 /// * `lhs` - The left-hand side tensor.
625 /// * `rhs` - The right-hand side scalar.
626 ///
627 /// # Returns
628 ///
629 /// The result of the multiplication.
630 fn int_mul_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
631
632 /// Element-wise division.
633 ///
634 /// # Arguments
635 ///
636 /// * `lhs` - The left-hand side tensor.
637 /// * `rhs` - The right-hand side tensor.
638 ///
639 /// # Returns
640 ///
641 /// The result of the division.
642 fn int_div(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
643
644 /// Element-wise division with a scalar.
645 ///
646 /// # Arguments
647 ///
648 /// * `lhs` - The left-hand side tensor.
649 /// * `rhs` - The right-hand side scalar.
650 ///
651 /// # Returns
652 ///
653 /// The result of the division.
654 fn int_div_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
655
656 /// Element-wise modulus.
657 ///
658 /// # Arguments
659 /// * `lhs` - The left-hand side tensor.
660 /// * `rhs` - The right-hand side scalar.
661 ///
662 /// # Returns
663 ///
664 /// The result of applying the modulus of the scalar to the tensor.
665 fn int_remainder(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
666
667 /// Element-wise modulus with a scalar.
668 ///
669 /// # Arguments
670 /// * `lhs` - The left-hand side tensor.
671 /// * `rhs` - The right-hand side scalar.
672 ///
673 /// # Returns
674 ///
675 /// The result of applying the modulus of the scalar to the tensor.
676 fn int_remainder_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
677
678 /// Multiplies two tensors together using matrix multiplication.
679 ///
680 /// # Arguments
681 ///
682 /// * `lhs` - The left-hand side tensor.
683 /// * `rhs` - The right-hand side tensor.
684 ///
685 /// # Returns
686 ///
687 /// The result of multiplying the two tensors together using matrix multiplication.
688 fn int_matmul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
689
690 /// Element-wise negation.
691 ///
692 /// # Arguments
693 ///
694 /// * `tensor` - The tensor to negate.
695 ///
696 /// # Returns
697 ///
698 /// The negated tensor.
699 fn int_neg(tensor: IntTensor<B>) -> IntTensor<B> {
700 Self::int_mul_scalar(tensor, (-1).into())
701 }
702
703 /// Creates a tensor of zeros.
704 ///
705 /// # Arguments
706 ///
707 /// * `shape` - The shape of the tensor.
708 /// * `device` - The device to create the tensor on.
709 /// * `dtype` - The target data type.
710 ///
711 /// # Returns
712 ///
713 /// The tensor of zeros.
714 fn int_zeros(shape: Shape, device: &Device<B>, dtype: IntDType) -> IntTensor<B> {
715 Self::int_from_data(TensorData::full_dtype(shape, 0, dtype.into()), device)
716 }
717
718 /// Creates a tensor of ones.
719 ///
720 /// # Arguments
721 ///
722 /// * `shape` - The shape of the tensor.
723 /// * `device` - The device to create the tensor on.
724 /// * `dtype` - The target data type.
725 ///
726 /// # Returns
727 ///
728 /// The tensor of ones.
729 fn int_ones(shape: Shape, device: &Device<B>, dtype: IntDType) -> IntTensor<B> {
730 Self::int_from_data(TensorData::full_dtype(shape, 1, dtype.into()), device)
731 }
732
733 /// Creates a tensor filled with given value.
734 ///
735 /// # Arguments
736 ///
737 /// * `shape` - The shape of the tensor.
738 /// * `fill_value` - The value with which to fill the tensor.
739 /// * `device` - The device to create the tensor on.
740 /// * `dtype` - The target data type.
741 ///
742 /// # Returns
743 ///
744 /// The tensor filled with given value
745 fn int_full(
746 shape: Shape,
747 fill_value: Scalar,
748 device: &Device<B>,
749 dtype: IntDType,
750 ) -> IntTensor<B> {
751 Self::int_from_data(
752 TensorData::full_dtype(shape, fill_value, dtype.into()),
753 device,
754 )
755 }
756
757 /// Sums all elements in the tensor.
758 ///
759 /// # Arguments
760 ///
761 /// * `tensor` - The tensor to sum.
762 ///
763 /// # Returns
764 ///
765 /// The sum of all elements in the tensor.
766 fn int_sum(tensor: IntTensor<B>) -> IntTensor<B>;
767
768 /// Sums all elements in the tensor along a dimension.
769 ///
770 /// # Arguments
771 ///
772 /// * `tensor` - The tensor to sum.
773 /// * `dim` - The dimension to sum along.
774 ///
775 /// # Returns
776 ///
777 /// The sum of all elements in the tensor along the dimension.
778 fn int_sum_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
779
780 /// Computes the product of all elements in the tensor.
781 ///
782 /// # Arguments
783 ///
784 /// * `tensor` - The tensor to compute the product of.
785 ///
786 /// # Returns
787 ///
788 /// The product of all elements in the tensor.
789 fn int_prod(tensor: IntTensor<B>) -> IntTensor<B>;
790
791 /// Computes the product of all elements in the tensor along a dimension.
792 ///
793 /// # Arguments
794 ///
795 /// * `tensor` - The tensor to compute the product of.
796 /// * `dim` - The dimension to compute the product along.
797 ///
798 /// # Returns
799 ///
800 /// The product of all elements in the tensor along the dimension.
801 fn int_prod_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
802
803 /// Computes the mean of all elements in the tensor.
804 ///
805 /// # Arguments
806 ///
807 /// * `tensor` - The tensor to compute the mean of.
808 ///
809 /// # Returns
810 ///
811 /// The mean of all elements in the tensor.
812 fn int_mean(tensor: IntTensor<B>) -> IntTensor<B> {
813 let num_elems = tensor.shape().num_elements() as i64;
814 B::int_div_scalar(B::int_sum(tensor), num_elems.into())
815 }
816
817 /// Computes the mean of all elements in the tensor along a dimension.
818 ///
819 /// # Arguments
820 ///
821 /// * `tensor` - The tensor to compute the mean of.
822 ///
823 /// # Returns
824 ///
825 /// The mean of all elements in the tensor along the dimension.
826 fn int_mean_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
827
828 /// Computes the cumulative sum of elements along a dimension.
829 ///
830 /// # Arguments
831 ///
832 /// * `tensor` - The tensor to compute the cumulative sum of.
833 /// * `dim` - The dimension along which to compute the cumulative sum.
834 ///
835 /// # Returns
836 ///
837 /// A tensor with the same shape where each element is the cumulative sum
838 /// of all elements up to and including that position along the dimension.
839 fn int_cumsum(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
840
841 /// Computes the cumulative product of elements along a dimension.
842 ///
843 /// # Arguments
844 ///
845 /// * `tensor` - The tensor to compute the cumulative product of.
846 /// * `dim` - The dimension along which to compute the cumulative product.
847 ///
848 /// # Returns
849 ///
850 /// A tensor with the same shape where each element is the cumulative product
851 /// of all elements up to and including that position along the dimension.
852 fn int_cumprod(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
853
854 /// Computes the cumulative minimum of elements along a dimension.
855 ///
856 /// # Arguments
857 ///
858 /// * `tensor` - The tensor to compute the cumulative minimum of.
859 /// * `dim` - The dimension along which to compute the cumulative minimum.
860 ///
861 /// # Returns
862 ///
863 /// A tensor with the same shape where each element is the minimum
864 /// of all elements up to and including that position along the dimension.
865 fn int_cummin(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
866
867 /// Computes the cumulative maximum of elements along a dimension.
868 ///
869 /// # Arguments
870 ///
871 /// * `tensor` - The tensor to compute the cumulative maximum of.
872 /// * `dim` - The dimension along which to compute the cumulative maximum.
873 ///
874 /// # Returns
875 ///
876 /// A tensor with the same shape where each element is the maximum
877 /// of all elements up to and including that position along the dimension.
878 fn int_cummax(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
879
880 /// Gets the indices of the maximum elements along a dimension.
881 ///
882 /// # Arguments
883 ///
884 /// * `tensor` - The tensor to get the maximum indices of.
885 /// * `dim` - The dimension to get the maximum indices along.
886 ///
887 /// # Returns
888 ///
889 /// The indices of the maximum elements along the dimension.
890 fn int_argmax(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
891
892 /// Gets the indices of the minimum elements along a dimension.
893 ///
894 /// # Arguments
895 ///
896 /// * `tensor` - The tensor to get the minimum indices of.
897 /// * `dim` - The dimension to get the minimum indices along.
898 ///
899 /// # Returns
900 ///
901 /// The indices of the minimum elements along the dimension.
902 fn int_argmin(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
903
904 /// Gets the maximum element in the tensor.
905 ///
906 /// # Arguments
907 ///
908 /// * `tensor` - The tensor to get the maximum element of.
909 ///
910 /// # Returns
911 ///
912 /// The maximum element in the tensor.
913 fn int_max(tensor: IntTensor<B>) -> IntTensor<B> {
914 let shape = tensor.shape();
915 let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()]));
916
917 B::int_max_dim(tensor, 0)
918 }
919
920 /// Gets the maximum element in the tensor along a dimension.
921 ///
922 /// # Arguments
923 ///
924 /// * `tensor` - The tensor to get the maximum element of.
925 /// * `dim` - The dimension to get the maximum element along.
926 ///
927 /// # Returns
928 ///
929 /// The maximum element in the tensor along the dimension.
930 fn int_max_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
931 let index = B::int_argmax(tensor.clone(), dim);
932 B::int_gather(dim, tensor, index)
933 }
934
935 /// Gets the maximum elements and corresponding indices along a dimension.
936 ///
937 /// # Arguments
938 ///
939 /// * `tensor` - The tensor to get the maximum elements and indices of.
940 /// * `dim` - The dimension to get the maximum elements and indices along.
941 ///
942 /// # Returns
943 ///
944 /// The maximum elements and corresponding indices along the dimension.
945 fn int_max_dim_with_indices(tensor: IntTensor<B>, dim: usize) -> (IntTensor<B>, IntTensor<B>) {
946 let index = B::int_argmax(tensor.clone(), dim);
947 let values = B::int_gather(dim, tensor, index.clone());
948
949 (values, index)
950 }
951
952 /// Gets the maximum absolute element in the tensor.
953 ///
954 /// # Arguments
955 ///
956 /// * `tensor` - The tensor to get the maximum element of.
957 ///
958 /// # Returns
959 ///
960 /// The maximum element in the tensor.
961 fn int_max_abs(tensor: IntTensor<B>) -> IntTensor<B> {
962 let shape = tensor.shape();
963 let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()]));
964
965 B::int_max_abs_dim(tensor, 0)
966 }
967
968 /// Gets the maximum absolute element in the tensor along a dimension.
969 ///
970 /// # Arguments
971 ///
972 /// * `tensor` - The tensor to get the maximum element of.
973 /// * `dim` - The dimension to get the maximum element along.
974 ///
975 /// # Returns
976 ///
977 /// The maximum element in the tensor along the dimension.
978 fn int_max_abs_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
979 B::int_max_dim(B::int_abs(tensor), dim)
980 }
981
982 /// Gets the minimum element in the tensor.
983 ///
984 /// # Arguments
985 ///
986 /// * `tensor` - The tensor to get the minimum element of.
987 ///
988 /// # Returns
989 ///
990 /// The minimum element in the tensor.
991 fn int_min(tensor: IntTensor<B>) -> IntTensor<B> {
992 let shape = tensor.shape();
993 let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()]));
994
995 B::int_min_dim(tensor, 0)
996 }
997
998 /// Gets the minimum elements in the tensor along a dimension.
999 ///
1000 /// # Arguments
1001 ///
1002 /// * `tensor` - The tensor to get the minimum element of.
1003 /// * `dim` - The dimension to get the minimum element along.
1004 ///
1005 /// # Returns
1006 ///
1007 /// The minimum element in the tensor along the dimension.
1008 fn int_min_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
1009 let index = B::int_argmin(tensor.clone(), dim);
1010 B::int_gather(dim, tensor, index)
1011 }
1012
1013 /// Gets the minimum elements and corresponding indices along a dimension.
1014 ///
1015 /// # Arguments
1016 ///
1017 /// * `tensor` - The tensor to get the minimum elements and indices of.
1018 /// * `dim` - The dimension to get the minimum elements and indices along.
1019 ///
1020 /// # Returns
1021 ///
1022 /// The minimum elements and corresponding indices along the dimension.
1023 fn int_min_dim_with_indices(tensor: IntTensor<B>, dim: usize) -> (IntTensor<B>, IntTensor<B>) {
1024 let indices = B::int_argmin(tensor.clone(), dim);
1025 let values = B::int_gather(dim, tensor, indices.clone());
1026
1027 (values, indices)
1028 }
1029
    /// Returns a new tensor with absolute values.
    ///
    /// This is a required method; every backend provides its own implementation.
    ///
    /// NOTE(review): the result for the minimum value of a signed dtype (whose absolute
    /// value is not representable in that dtype) is left to the backend — confirm.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take absolute value of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with absolute values.
    fn int_abs(tensor: IntTensor<B>) -> IntTensor<B>;
1040
1041 /// Transposes an int tensor.
1042 ///
1043 /// # Arguments
1044 ///
1045 /// * `tensor` - The tensor to transpose.
1046 ///
1047 /// # Returns
1048 ///
1049 /// The transposed tensor.
1050 fn int_transpose(tensor: IntTensor<B>) -> IntTensor<B> {
1051 let ndims = tensor.shape().num_dims();
1052 Self::int_swap_dims(tensor, ndims - 2, ndims - 1)
1053 }
1054
    /// Swaps two dimensions of an int tensor.
    ///
    /// This is a required method; every backend provides its own implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to swap the dimensions of.
    /// * `dim1` - The first dimension to swap.
    /// * `dim2` - The second dimension to swap.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions `dim1` and `dim2` swapped.
    fn int_swap_dims(tensor: IntTensor<B>, dim1: usize, dim2: usize) -> IntTensor<B>;
1067
    /// Permutes the dimensions of a tensor.
    ///
    /// This is a required method; every backend provides its own implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to permute the dimensions of.
    /// * `axes` - The new order of the dimensions.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions permuted.
    fn int_permute(tensor: IntTensor<B>, axes: &[usize]) -> IntTensor<B>;
1078
    /// Reverses the order of elements in a tensor along the given axes.
    ///
    /// This is a required method; every backend provides its own implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to reverse.
    /// * `axes` - The axes to reverse.
    ///
    /// # Returns
    ///
    /// The tensor with the elements reversed.
    fn int_flip(tensor: IntTensor<B>, axes: &[usize]) -> IntTensor<B>;
1088
    /// Creates a new int tensor with random values.
    ///
    /// This is a required method; every backend provides its own implementation.
    ///
    /// # Arguments
    ///
    /// * `shape` - The shape of the tensor.
    /// * `distribution` - The distribution to sample from.
    /// * `device` - The device to create the tensor on.
    /// * `dtype` - The target data type.
    ///
    /// # Returns
    ///
    /// The tensor with the given shape and random values.
    fn int_random(
        shape: Shape,
        distribution: Distribution,
        device: &Device<B>,
        dtype: IntDType,
    ) -> IntTensor<B>;
1106
1107 /// Creates a new tensor with values from the given range with the given step size.
1108 ///
1109 /// # Arguments
1110 ///
1111 /// * `range` - The range of values.
1112 /// * `step` - The step size.
1113 /// * `device` - The device to create the tensor on.
1114 /// * `dtype` - The target data type.
1115 ///
1116 /// # Returns
1117 ///
1118 /// The tensor with the given values.
1119 fn int_arange_step(
1120 range: Range<i64>,
1121 step: usize,
1122 device: &Device<B>,
1123 dtype: IntDType,
1124 ) -> IntTensor<B> {
1125 let value = range
1126 .step_by(step)
1127 .map(|i| i.elem())
1128 .collect::<Vec<IntElem<B>>>();
1129 let shape = Shape::new([value.len()]);
1130 let data = TensorData::new(value, shape).convert_dtype(dtype.into());
1131 B::int_from_data(data, device)
1132 }
1133
1134 /// Creates a new tensor with values from the given range.
1135 ///
1136 /// # Arguments
1137 ///
1138 /// * `range` - The range of values.
1139 /// * `device` - The device to create the tensor on.
1140 ///
1141 /// # Returns
1142 ///
1143 /// The tensor with the given values.
1144 ///
1145 /// # Remarks
1146 ///
1147 /// Uses `arange_step` with a step size of 1 under the hood.
1148 fn int_arange(range: Range<i64>, device: &Device<B>, dtype: IntDType) -> IntTensor<B> {
1149 Self::int_arange_step(range, 1, device, dtype)
1150 }
1151
1152 /// Tests if any element in the int `tensor` evaluates to True.
1153 ///
1154 /// # Arguments
1155 ///
1156 /// * `tensor` - The tensor to test.
1157 ///
1158 /// # Returns
1159 ///
1160 /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
1161 fn int_any(tensor: IntTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
1162 let int_dtype = tensor.dtype();
1163 let bool_tensor = B::int_equal_elem(tensor, 0.into(), out_dtype);
1164 let bool_tensor = B::bool_not(bool_tensor);
1165 let sum = B::int_sum(B::bool_into_int(bool_tensor, int_dtype.into()));
1166 B::int_greater_elem(sum, 0.into(), out_dtype)
1167 }
1168
1169 /// Tests if any element in the int `tensor` evaluates to True along a given dimension `dim`.
1170 ///
1171 /// # Arguments
1172 ///
1173 /// * `tensor` - The tensor to test.
1174 /// * `dim` - The axis along which to test.
1175 ///
1176 /// # Returns
1177 ///
1178 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1179 /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input
1180 /// evaluates to True, False otherwise.
1181 fn int_any_dim(tensor: IntTensor<B>, dim: usize, out_dtype: BoolDType) -> BoolTensor<B> {
1182 let int_dtype = tensor.dtype();
1183 let bool_tensor = B::int_equal_elem(tensor, 0.into(), out_dtype);
1184 let bool_tensor = B::bool_not(bool_tensor);
1185 let sum = B::int_sum_dim(B::bool_into_int(bool_tensor, int_dtype.into()), dim);
1186 B::int_greater_elem(sum, 0.into(), out_dtype)
1187 }
1188
1189 /// Tests if all elements in the int `tensor` evaluate to True.
1190 ///
1191 /// # Arguments
1192 ///
1193 /// * `tensor` - The tensor to test.
1194 /// * `out_dtype` - The output tensor dtype.
1195 ///
1196 /// # Returns
1197 ///
1198 /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1199 /// evaluate to True, False otherwise.
1200 fn int_all(tensor: IntTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
1201 let int_dtype = tensor.dtype();
1202 let num_elems = tensor.shape().num_elements() as i64;
1203 let bool_tensor = B::int_equal_elem(tensor, 0.into(), out_dtype);
1204 let bool_tensor = B::bool_not(bool_tensor);
1205 let sum = B::int_sum(B::bool_into_int(bool_tensor, int_dtype.into()));
1206 B::int_equal_elem(sum, num_elems.into(), out_dtype)
1207 }
1208
1209 /// Tests if all elements in the int `tensor` evaluate to True along a given dimension `dim`.
1210 ///
1211 /// # Arguments
1212 ///
1213 /// * `tensor` - The tensor to test.
1214 /// * `dim` - The axis along which to test.
1215 /// * `out_dtype` - The output tensor dtype.
1216 ///
1217 /// # Returns
1218 ///
1219 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1220 /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1221 /// evaluates to True, False otherwise.
1222 fn int_all_dim(tensor: IntTensor<B>, dim: usize, out_dtype: BoolDType) -> BoolTensor<B> {
1223 let int_dtype = tensor.dtype();
1224 let num_elems = tensor.shape()[dim] as i64;
1225 let bool_tensor = B::int_equal_elem(tensor, 0.into(), out_dtype);
1226 let bool_tensor = B::bool_not(bool_tensor);
1227 let sum = B::int_sum_dim(B::bool_into_int(bool_tensor, int_dtype.into()), dim);
1228 B::int_equal_elem(sum, num_elems.into(), out_dtype)
1229 }
1230
1231 /// Returns the signs of the int `tensor`.
1232 ///
1233 /// # Arguments
1234 ///
1235 /// * `tensor` - The tensor to extract the signs from.
1236 ///
1237 /// # Returns
1238 ///
1239 /// A tensor with the same shape as `tensor` containing the signs of the elements of `tensor`.
1240 fn int_sign(tensor: IntTensor<B>) -> IntTensor<B> {
1241 let dtype = tensor.dtype();
1242 let device = B::int_device(&tensor);
1243 let bool_dtype = get_device_settings::<B>(&B::int_device(&tensor)).bool_dtype;
1244 let zeros = B::int_zeros(tensor.shape(), &device, dtype.into());
1245 let less_than_zero = B::int_lower_elem(tensor.clone(), 0.into(), bool_dtype);
1246 let greater_than_zero = B::int_greater_elem(tensor, 0.into(), bool_dtype);
1247
1248 let mut result = B::int_mask_fill(zeros, less_than_zero, (-1).into());
1249 result = B::int_mask_fill(result, greater_than_zero, 1.into());
1250 result
1251 }
1252
    /// Broadcasts the int `tensor` to the given `shape`.
    ///
    /// This is a required method; every backend provides its own implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to broadcast.
    /// * `shape` - The target shape.
    ///
    /// # Returns
    ///
    /// The tensor broadcast to `shape`.
    fn int_expand(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B>;
1255
1256 /// Sort the elements of the input `tensor` by value along a given dimension.
1257 ///
1258 /// This sort is unstable (i.e., may reorder equal elements).
1259 ///
1260 /// # Arguments
1261 ///
1262 /// * `tensor` - The input tensor.
1263 /// * `dim` - The axis along which to sort.
1264 /// * `descending` - The sorting order.
1265 ///
1266 /// # Returns
1267 ///
1268 /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
1269 fn int_sort(tensor: IntTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
1270 sort::<B, Int>(tensor, dim, descending)
1271 }
1272
1273 /// Sort the elements of the input `tensor` by value along a given dimension.
1274 ///
1275 /// This sort is unstable (i.e., may reorder equal elements).
1276 ///
1277 /// # Arguments
1278 ///
1279 /// * `tensor` - The input tensor.
1280 /// * `dim` - The axis along which to sort.
1281 ///
1282 /// # Returns
1283 ///
1284 /// A tensor with the same shape as the input tensor and corresponding indices, where
1285 /// the elements are sorted by value and the indices map back to the original input tensor.
1286 fn int_sort_with_indices(
1287 tensor: IntTensor<B>,
1288 dim: usize,
1289 descending: bool,
1290 ) -> (IntTensor<B>, IntTensor<B>) {
1291 let dtype = tensor.dtype();
1292 sort_with_indices::<B, Int>(tensor, dim, descending, dtype.into())
1293 }
1294
1295 /// Returns the indices that sort the elements of the input `tensor` by value
1296 /// along a given dimension.
1297 ///
1298 /// This sort is unstable (i.e., may reorder equal elements).
1299 ///
1300 /// # Arguments
1301 ///
1302 /// * `tensor` - The input tensor.
1303 /// * `dim` - The axis along which to sort.
1304 /// * `descending` - The sorting order.
1305 ///
1306 /// # Returns
1307 ///
1308 /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.
1309 fn int_argsort(tensor: IntTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
1310 let dtype = tensor.dtype();
1311 argsort::<B, Int>(tensor, dim, descending, dtype.into())
1312 }
1313
    /// Bitwise AND operation for Int Tensors.
    ///
    /// Computes the bitwise AND between the int tensors `lhs` and `rhs`.
    fn bitwise_and(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
1316
    /// Bitwise AND operation for Int Tensors with a scalar.
    ///
    /// Computes the bitwise AND between the int tensor `lhs` and the scalar `rhs`.
    fn bitwise_and_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
1319
    /// Bitwise OR operation for Int Tensors.
    ///
    /// Computes the bitwise OR between the int tensors `lhs` and `rhs`.
    fn bitwise_or(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
1322
    /// Bitwise OR operation for Int Tensors with a scalar.
    ///
    /// Computes the bitwise OR between the int tensor `lhs` and the scalar `rhs`.
    fn bitwise_or_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
1325
    /// Bitwise XOR operation for Int Tensors.
    ///
    /// Computes the bitwise exclusive OR between the int tensors `lhs` and `rhs`.
    fn bitwise_xor(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
1328
    /// Bitwise XOR operation for Int Tensors with a scalar.
    ///
    /// Computes the bitwise exclusive OR between the int tensor `lhs` and the scalar `rhs`.
    fn bitwise_xor_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
1331
    /// Bitwise NOT operation for Int Tensors.
    ///
    /// Computes the bitwise complement of the int tensor `tensor`.
    fn bitwise_not(tensor: IntTensor<B>) -> IntTensor<B>;
1334
    /// Bitwise left shift operation for Int Tensors.
    ///
    /// Shifts the bits of `lhs` to the left by the amounts given in `rhs`.
    fn bitwise_left_shift(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
1337
    /// Bitwise left shift operation for Int Tensors with a scalar.
    ///
    /// Shifts the bits of `lhs` to the left by the scalar amount `rhs`.
    fn bitwise_left_shift_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
1340
    /// Bitwise right shift operation for Int Tensors.
    ///
    /// Shifts the bits of `lhs` to the right by the amounts given in `rhs`.
    fn bitwise_right_shift(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
1343
    /// Bitwise right shift operation for Int Tensors with a scalar.
    ///
    /// Shifts the bits of `lhs` to the right by the scalar amount `rhs`.
    fn bitwise_right_shift_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
1346
    /// Converts a tensor to another integer data type.
    ///
    /// This is a required method; every backend provides its own implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to convert.
    /// * `dtype` - The target data type.
    ///
    /// # Returns
    ///
    /// A tensor with the same values as `tensor` but in the target integer data type.
    fn int_cast(tensor: IntTensor<B>, dtype: IntDType) -> IntTensor<B>;
1358
    /// Unfold windows along a dimension.
    ///
    /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;
    /// where windows are advanced by `step` at each index.
    ///
    /// The number of windows is `(shape[dim] - size) / step + 1` (integer division) when
    /// `shape[dim] >= size`, and `0` otherwise. For example, `shape[dim] == size` yields
    /// exactly one window.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
    /// * `dim` - the selected dim.
    /// * `size` - the size of each unfolded window.
    /// * `step` - the step between each window.
    ///
    /// # Returns
    ///
    /// A tensor view with shape ``[pre=..., windows, size, post=...]``.
    fn int_unfold(tensor: IntTensor<B>, dim: usize, size: usize, step: usize) -> IntTensor<B>;
1377}