burn_tensor/tensor/ops/int_tensor.rs
1use super::cat::cat_with_slice_assign;
2use super::repeat_dim::repeat_with_slice_assign;
3use super::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
4use crate::{
5 Distribution, ElementConversion, Int, IntDType, TensorData, backend::Backend, tensor::Shape,
6};
7use crate::{TensorMetadata, argsort, sort, sort_with_indices};
8use alloc::vec::Vec;
9use core::ops::Range;
10
11/// Int Tensor API for basic and numeric operations, see [tensor](crate::Tensor)
12/// for documentation on each function.
13pub trait IntTensorOps<B: Backend> {
14 /// Creates a new int tensor.
15 ///
16 /// # Arguments
17 ///
18 /// * `shape` - The shape of the tensor.
19 /// * `device` - The device to create the tensor on.
20 /// * `dtype` - The target data type.
21 ///
22 /// # Returns
23 ///
24 /// The integer tensor with the given shape.
25 fn int_empty(shape: Shape, device: &Device<B>, dtype: IntDType) -> IntTensor<B>;
26
27 /// Converts the tensor to a data structure.
28 ///
29 /// # Arguments
30 ///
31 /// * `tensor` - The tensor.
32 ///
33 /// # Returns
34 ///
35 /// The data structure with the tensor's data.
36 fn int_into_data(tensor: IntTensor<B>) -> impl Future<Output = TensorData> + Send;
37
38 /// Creates a tensor from the data structure.
39 ///
40 /// # Arguments
41 ///
42 /// * `data` - The data structure.
43 /// * `device` - The device to create the tensor on.
44 ///
45 /// # Returns
46 ///
47 /// The tensor with the data.
48 fn int_from_data(data: TensorData, device: &Device<B>) -> IntTensor<B>;
49
50 /// Gets the device of the tensor.
51 ///
52 /// # Arguments
53 ///
54 /// * `tensor` - The tensor.
55 ///
56 /// # Returns
57 ///
58 /// The device of the tensor.
59 fn int_device(tensor: &IntTensor<B>) -> Device<B>;
60
61 /// Moves the tensor to the given device.
62 fn int_to_device(tensor: IntTensor<B>, device: &Device<B>) -> IntTensor<B>;
63
64 /// Reshapes the tensor.
65 ///
66 /// # Arguments
67 ///
68 /// * `tensor` - The tensor.
69 /// * `shape` - The new shape.
70 ///
71 /// # Returns
72 ///
73 /// The tensor with the new shape.
74 fn int_reshape(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B>;
75
76 /// Gets the element at the given indices.
77 ///
78 /// # Arguments
79 ///
80 /// * `tensor` - The tensor.
81 /// * `slices` - The slices specifying ranges and steps for each dimension.
82 ///
83 /// # Returns
84 ///
85 /// The elements at the given indices.
86 fn int_slice(tensor: IntTensor<B>, slices: &[crate::Slice]) -> IntTensor<B>;
87
88 /// Sets the values in the tensor for the given ranges.
89 ///
90 /// # Arguments
91 ///
92 /// * `tensor` - The tensor.
93 /// * `ranges` - The ranges to set the values for.
94 ///
95 /// # Returns
96 ///
97 /// The tensor with the values set for the given ranges.
98 fn int_slice_assign(
99 tensor: IntTensor<B>,
100 slices: &[crate::Slice],
101 value: IntTensor<B>,
102 ) -> IntTensor<B>;
103
104 /// Converts int tensor to float tensor.
105 ///
106 /// # Arguments
107 ///
108 /// * `tensor` - The tensor.
109 ///
110 /// # Returns
111 ///
112 /// The int tensor with the same data as the float tensor.
113 fn int_into_float(tensor: IntTensor<B>) -> FloatTensor<B>;
114
115 /// Fills the tensor with values from the source tensor if the mask is true at the given
116 /// indices.
117 ///
118 /// # Arguments
119 ///
120 /// * `tensor` - The tensor.
121 /// * `mask` - The mask.
122 /// * `source` - The source tensor.
123 ///
124 /// # Returns
125 ///
126 /// The tensor with the values filled.
127 fn int_mask_where(
128 tensor: IntTensor<B>,
129 mask: BoolTensor<B>,
130 source: IntTensor<B>,
131 ) -> IntTensor<B>;
132
133 /// Fills the tensor with the given value if the mask is true at the given indices.
134 ///
135 /// # Arguments
136 ///
137 /// * `tensor` - The tensor.
138 /// * `mask` - The mask.
139 /// * `value` - The value.
140 ///
141 /// # Returns
142 ///
143 /// The tensor with the values filled.
144 fn int_mask_fill(tensor: IntTensor<B>, mask: BoolTensor<B>, value: IntElem<B>) -> IntTensor<B>;
145
146 /// Gather elements from the tensor at the given indices.
147 ///
148 /// # Arguments
149 ///
150 /// * `dim` - The dimension to gather from.
151 /// * `tensor` - The tensor.
152 /// * `indices` - The indices.
153 fn int_gather(dim: usize, tensor: IntTensor<B>, indices: IntTensor<B>) -> IntTensor<B>;
154
155 /// Scatter a given value to the tensor at the given indices.
156 ///
157 /// # Arguments
158 ///
159 /// * `dim` - The dimension to scatter to.
160 /// * `tensor` - The tensor.
161 /// * `indices` - The indices.
162 /// * `value` - The value.
163 ///
164 /// # Returns
165 ///
166 /// The tensor with the values scattered.
167 fn int_scatter(
168 dim: usize,
169 tensor: IntTensor<B>,
170 indices: IntTensor<B>,
171 value: IntTensor<B>,
172 ) -> IntTensor<B>;
173
174 /// Select tensor elements along the given dimension corresponding to the given indices.
175 ///
176 /// # Arguments
177 ///
178 /// * `tensor` - The tensor.
179 /// * `dim` - The dimension to select from.
180 /// * `indices` - The indices.
181 ///
182 /// # Returns
183 ///
184 /// The tensor with the selected elements.
185 fn int_select(tensor: IntTensor<B>, dim: usize, indices: IntTensor<B>) -> IntTensor<B>;
186
187 /// Assign the selected elements along the given dimension corresponding to the given indices
188 /// to the given value.
189 ///
190 /// # Arguments
191 ///
192 /// * `tensor` - The tensor.
193 /// * `dim` - The dimension to select from.
194 /// * `indices` - The indices.
195 /// * `value` - The value.
196 ///
197 /// # Returns
198 ///
199 /// The tensor with the selected elements assigned to the given value.
200 fn int_select_assign(
201 tensor: IntTensor<B>,
202 dim: usize,
203 indices: IntTensor<B>,
204 value: IntTensor<B>,
205 ) -> IntTensor<B>;
206
207 /// Repeats the tensor along the given dimension the given number of times.
208 ///
209 /// # Arguments
210 ///
211 /// * `tensor` - The tensor.
212 /// * `dim` - The dimension to repeat.
213 /// * `times` - The number of times to repeat.
214 ///
215 /// # Returns
216 ///
217 /// The tensor with the given dimension repeated the given number of times.
218 fn int_repeat_dim(tensor: IntTensor<B>, dim: usize, times: usize) -> IntTensor<B> {
219 repeat_with_slice_assign::<B, Int>(tensor, dim, times)
220 }
221
222 /// Concatenates the given tensors along the given dimension.
223 ///
224 /// # Arguments
225 ///
226 /// * `tensors` - The tensors.
227 /// * `dim` - The dimension to concatenate along.
228 ///
229 /// # Returns
230 ///
231 /// The concatenated tensor.
232 fn int_cat(tensors: Vec<IntTensor<B>>, dim: usize) -> IntTensor<B> {
233 cat_with_slice_assign::<B, Int>(tensors, dim)
234 }
235
236 /// Element-wise equality comparison.
237 ///
238 /// # Arguments
239 ///
240 /// * `lhs` - The left-hand side tensor.
241 /// * `rhs` - The right-hand side tensor.
242 ///
243 /// # Returns
244 ///
245 /// The boolean tensor with the result of the comparison.
246 fn int_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;
247
248 /// Element-wise non-equality comparison.
249 ///
250 /// # Arguments
251 ///
252 /// * `lhs` - The left-hand side tensor.
253 /// * `rhs` - The right-hand side tensor.
254 ///
255 /// # Returns
256 ///
257 /// The boolean tensor with the result of the comparison.
258 fn int_not_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {
259 let equal_tensor = B::int_equal(lhs, rhs);
260 B::bool_not(equal_tensor)
261 }
262
263 /// Element-wise equality comparison with a scalar.
264 ///
265 /// # Arguments
266 ///
267 /// * `lhs` - The left-hand side tensor.
268 /// * `rhs` - The right-hand side scalar.
269 ///
270 /// # Returns
271 ///
272 /// The boolean tensor with the result of the comparison.
273 fn int_equal_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;
274
275 /// Element-wise non-equality comparison with a scalar.
276 ///
277 /// # Arguments
278 ///
279 /// * `lhs` - The left-hand side tensor.
280 /// * `rhs` - The right-hand side scalar.
281 ///
282 /// # Returns
283 ///
284 /// The boolean tensor with the result of the comparison.
285 fn int_not_equal_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B> {
286 let equal_tensor = B::int_equal_elem(lhs, rhs);
287 B::bool_not(equal_tensor)
288 }
289
290 /// Element-wise greater than comparison.
291 ///
292 /// # Arguments
293 ///
294 /// * `lhs` - The left-hand side tensor.
295 /// * `rhs` - The right-hand side tensor.
296 ///
297 /// # Returns
298 ///
299 /// The boolean tensor with the result of the comparison.
300 fn int_greater(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;
301
302 /// Element-wise greater than comparison with a scalar.
303 ///
304 /// # Arguments
305 ///
306 /// * `lhs` - The left-hand side tensor.
307 /// * `rhs` - The right-hand side scalar.
308 ///
309 /// # Returns
310 ///
311 /// The boolean tensor with the result of the comparison.
312 fn int_greater_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;
313
314 /// Element-wise greater than or equal comparison.
315 ///
316 /// # Arguments
317 ///
318 /// * `lhs` - The left-hand side tensor.
319 /// * `rhs` - The right-hand side tensor.
320 ///
321 /// # Returns
322 ///
323 /// The boolean tensor with the result of the comparison.
324 fn int_greater_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;
325
326 /// Element-wise greater than or equal comparison with a scalar.
327 ///
328 /// # Arguments
329 ///
330 /// * `lhs` - The left-hand side tensor.
331 /// * `rhs` - The right-hand side scalar.
332 ///
333 /// # Returns
334 ///
335 /// The boolean tensor with the result of the comparison.
336 fn int_greater_equal_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;
337
338 /// Element-wise less than comparison.
339 ///
340 /// # Arguments
341 ///
342 /// * `lhs` - The left-hand side tensor.
343 /// * `rhs` - The right-hand side tensor.
344 ///
345 /// # Returns
346 ///
347 /// The boolean tensor with the result of the comparison.
348 fn int_lower(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;
349
350 /// Element-wise less than comparison with a scalar.
351 ///
352 /// # Arguments
353 ///
354 /// * `lhs` - The left-hand side tensor.
355 /// * `rhs` - The right-hand side scalar.
356 ///
357 /// # Returns
358 ///
359 /// The boolean tensor with the result of the comparison.
360 fn int_lower_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;
361
362 /// Element-wise less than or equal comparison.
363 ///
364 /// # Arguments
365 ///
366 /// * `lhs` - The left-hand side tensor.
367 /// * `rhs` - The right-hand side tensor.
368 ///
369 /// # Returns
370 ///
371 /// The boolean tensor with the result of the comparison.
372 fn int_lower_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;
373
374 /// Element-wise less than or equal comparison with a scalar.
375 ///
376 /// # Arguments
377 ///
378 /// * `lhs` - The left-hand side tensor.
379 /// * `rhs` - The right-hand side scalar.
380 ///
381 /// # Returns
382 ///
383 /// The boolean tensor with the result of the comparison.
384 fn int_lower_equal_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;
385
386 // ==== NUMERIC ==== //
387
388 /// Element-wise addition.
389 ///
390 /// # Arguments
391 ///
392 /// * `lhs` - The left-hand side tensor.
393 /// * `rhs` - The right-hand side tensor.
394 ///
395 /// # Returns
396 ///
397 /// The result of the addition.
398 fn int_add(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
399
400 /// Element-wise addition with a scalar.
401 ///
402 /// # Arguments
403 ///
404 /// * `lhs` - The left-hand side tensor.
405 /// * `rhs` - The right-hand side scalar.
406 ///
407 /// # Returns
408 ///
409 /// The result of the addition.
410 fn int_add_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
411
412 /// Element-wise power with a IntTensor.
413 ///
414 /// # Arguments
415 ///
416 /// * `lhs` - The left-hand side IntTensor.
417 /// * `rhs` - The right-hand side IntTensor.
418 ///
419 /// # Returns
420 ///
421 /// The elements of `lhs` raised to the power of the elements of `rhs`.
422 fn int_powi(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
423 B::float_into_int(B::float_powi(B::int_into_float(lhs), rhs))
424 }
425
426 /// Element-wise power with a floatTensor.
427 ///
428 /// # Arguments
429 ///
430 /// * `lhs` - The left-hand side tensor.
431 /// * `rhs` - The right-hand side floatTensor.
432 ///
433 /// # Returns
434 ///
435 /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
436 fn int_powf(lhs: IntTensor<B>, rhs: FloatTensor<B>) -> IntTensor<B> {
437 B::float_into_int(B::float_powf(B::int_into_float(lhs), rhs))
438 }
439
440 /// Element-wise power with a scalar.
441 ///
442 /// # Backend Implementors Note
443 ///
444 /// A number of common exponent cases can be implemented with operations
445 /// which are much cheaper than generic exponentiation.
446 ///
447 /// This (`Backend` impl overridable) operation handles generic optimizations
448 /// for several common integer exponent cases; and then dispatches to
449 /// the (`Backend` impl overridable) [`Self::int_powi_scalar_impl`]
450 /// operation to handle the generic case.
451 ///
452 /// # Arguments
453 ///
454 /// * `lhs` - The left-hand side tensor.
455 /// * `rhs` - The right-hand side scalar.
456 ///
457 /// # Returns
458 ///
459 /// The elements of `lhs` raised to the value of `rhs`.
460 fn int_powi_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B> {
461 let exp = rhs.elem::<i32>();
462 match exp {
463 0 => Self::int_ones(lhs.shape(), &B::int_device(&lhs), lhs.dtype().into()),
464 1 => lhs,
465 2 => Self::int_mul(lhs.clone(), lhs),
466 _ => Self::int_powi_scalar_impl(lhs, rhs),
467 }
468 }
469
470 /// Element-wise power with a scalar.
471 ///
472 /// # Backend Implementors Note
473 ///
474 /// This is the generic implementation of integer exponentiation
475 /// called by [`Self::int_powi_scalar`] in the fallback case.
476 ///
477 /// By default, this performs a relatively expensive conversion to float,
478 /// exponentiation in float, and conversion back to int.
479 /// This reduces the minimal operation set for `Backend`s,
480 /// at the cost of performance.
481 ///
482 /// This is a good target for specialized optimizations in `Backend` implementations.
483 ///
484 /// As a general rule, this should not be called directly.
485 ///
486 /// # Arguments
487 ///
488 /// * `lhs` - The left-hand side tensor.
489 /// * `rhs` - The right-hand side scalar.
490 ///
491 /// # Returns
492 ///
493 /// The elements of `lhs` raised to the value of `rhs`.
494 fn int_powi_scalar_impl(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B> {
495 B::float_into_int(B::float_powi_scalar_impl(B::int_into_float(lhs), rhs))
496 }
497
498 /// Element-wise power with a floatTensor.
499 ///
500 /// Handles a number of special cases, then calls [`Self::int_powf_scalar_impl`].
501 ///
502 /// # Arguments
503 ///
504 /// * `lhs` - The left-hand side tensor.
505 /// * `rhs` - The right-hand side scalar.
506 ///
507 /// # Returns
508 ///
509 /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
510 fn int_powf_scalar(lhs: IntTensor<B>, rhs: f32) -> IntTensor<B> {
511 if num_traits::Float::floor(rhs) == rhs {
512 let exp = B::IntElem::from_elem(rhs as i32);
513 Self::int_powi_scalar(lhs, exp)
514 } else {
515 Self::int_powf_scalar_impl(lhs, rhs)
516 }
517 }
518
519 /// Element-wise power with a floatTensor.
520 ///
521 /// Fallback handler for [`Self::int_powf_scalar`].
522 ///
523 /// # Arguments
524 ///
525 /// * `lhs` - The left-hand side tensor.
526 /// * `rhs` - The right-hand side scalar.
527 ///
528 /// # Returns
529 ///
530 /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
531 fn int_powf_scalar_impl(lhs: IntTensor<B>, rhs: f32) -> IntTensor<B> {
532 B::float_into_int(B::float_powf_scalar_impl(B::int_into_float(lhs), rhs))
533 }
534
535 /// Clamps a tensor under a minimum value.
536 ///
537 /// # Arguments
538 ///
539 /// * `tensor` - The tensor to clamp.
540 /// * `min` - The minimum value.
541 ///
542 /// # Returns
543 ///
544 /// The clamped tensor.
545 fn int_clamp_min(tensor: IntTensor<B>, min: IntElem<B>) -> IntTensor<B> {
546 let mask = Self::int_lower_elem(tensor.clone(), min);
547 Self::int_mask_fill(tensor, mask, min)
548 }
549
550 /// Clamps a tensor over a maximum value.
551 ///
552 /// # Arguments
553 ///
554 /// * `tensor` - The tensor to clamp.
555 /// * `max` - The maximum value.
556 ///
557 /// # Returns
558 ///
559 /// The clamped tensor.
560 fn int_clamp_max(tensor: IntTensor<B>, max: IntElem<B>) -> IntTensor<B> {
561 let mask = Self::int_greater_elem(tensor.clone(), max);
562 Self::int_mask_fill(tensor, mask, max)
563 }
564
565 /// Clamps a tensor between a minimum and maximum value.
566 ///
567 /// # Arguments
568 ///
569 /// * `tensor` - The tensor to clamp.
570 /// * `min` - The minimum value.
571 /// * `max` - The maximum value.
572 ///
573 /// # Returns
574 ///
575 /// The clamped tensor.
576 fn int_clamp(tensor: IntTensor<B>, min: IntElem<B>, max: IntElem<B>) -> IntTensor<B> {
577 Self::int_clamp_min(Self::int_clamp_max(tensor, max), min)
578 }
579
580 /// Element-wise subtraction.
581 ///
582 /// # Arguments
583 ///
584 /// * `lhs` - The left-hand side tensor.
585 /// * `rhs` - The right-hand side tensor.
586 ///
587 /// # Returns
588 ///
589 /// The result of the subtraction.
590 fn int_sub(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
591
592 /// Element-wise subtraction with a scalar.
593 ///
594 /// # Arguments
595 ///
596 /// * `lhs` - The left-hand side tensor.
597 /// * `rhs` - The right-hand side scalar.
598 ///
599 /// # Returns
600 ///
601 /// The result of the subtraction.
602 fn int_sub_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
603
604 /// Element-wise multiplication.
605 ///
606 /// # Arguments
607 ///
608 /// * `lhs` - The left-hand side tensor.
609 /// * `rhs` - The right-hand side tensor.
610 ///
611 /// # Returns
612 ///
613 /// The result of the multiplication.
614 fn int_mul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
615
616 /// Element-wise multiplication with a scalar.
617 ///
618 /// # Arguments
619 ///
620 /// * `lhs` - The left-hand side tensor.
621 /// * `rhs` - The right-hand side scalar.
622 ///
623 /// # Returns
624 ///
625 /// The result of the multiplication.
626 fn int_mul_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
627
628 /// Element-wise division.
629 ///
630 /// # Arguments
631 ///
632 /// * `lhs` - The left-hand side tensor.
633 /// * `rhs` - The right-hand side tensor.
634 ///
635 /// # Returns
636 ///
637 /// The result of the division.
638 fn int_div(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
639
640 /// Element-wise division with a scalar.
641 ///
642 /// # Arguments
643 ///
644 /// * `lhs` - The left-hand side tensor.
645 /// * `rhs` - The right-hand side scalar.
646 ///
647 /// # Returns
648 ///
649 /// The result of the division.
650 fn int_div_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
651
652 /// Element-wise modulus.
653 ///
654 /// # Arguments
655 /// * `lhs` - The left-hand side tensor.
656 /// * `rhs` - The right-hand side scalar.
657 ///
658 /// # Returns
659 ///
660 /// The result of applying the modulus of the scalar to the tensor.
661 fn int_remainder(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
662
663 /// Element-wise modulus with a scalar.
664 ///
665 /// # Arguments
666 /// * `lhs` - The left-hand side tensor.
667 /// * `rhs` - The right-hand side scalar.
668 ///
669 /// # Returns
670 ///
671 /// The result of applying the modulus of the scalar to the tensor.
672 fn int_remainder_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
673
674 /// Multiplies two tensors together using matrix multiplication.
675 ///
676 /// # Arguments
677 ///
678 /// * `lhs` - The left-hand side tensor.
679 /// * `rhs` - The right-hand side tensor.
680 ///
681 /// # Returns
682 ///
683 /// The result of multiplying the two tensors together using matrix multiplication.
684 fn int_matmul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
685
686 /// Element-wise negation.
687 ///
688 /// # Arguments
689 ///
690 /// * `tensor` - The tensor to negate.
691 ///
692 /// # Returns
693 ///
694 /// The negated tensor.
695 fn int_neg(tensor: IntTensor<B>) -> IntTensor<B> {
696 Self::int_mul_scalar(tensor, (-1.0).elem::<IntElem<B>>())
697 }
698
699 /// Creates a tensor of zeros.
700 ///
701 /// # Arguments
702 ///
703 /// * `shape` - The shape of the tensor.
704 /// * `device` - The device to create the tensor on.
705 /// * `dtype` - The target data type.
706 ///
707 /// # Returns
708 ///
709 /// The tensor of zeros.
710 fn int_zeros(shape: Shape, device: &Device<B>, dtype: IntDType) -> IntTensor<B> {
711 Self::int_from_data(TensorData::full_dtype(shape, 0, dtype.into()), device)
712 }
713
714 /// Creates a tensor of ones.
715 ///
716 /// # Arguments
717 ///
718 /// * `shape` - The shape of the tensor.
719 /// * `device` - The device to create the tensor on.
720 /// * `dtype` - The target data type.
721 ///
722 /// # Returns
723 ///
724 /// The tensor of ones.
725 fn int_ones(shape: Shape, device: &Device<B>, dtype: IntDType) -> IntTensor<B> {
726 Self::int_from_data(TensorData::full_dtype(shape, 1, dtype.into()), device)
727 }
728
729 /// Creates a tensor filled with given value.
730 ///
731 /// # Arguments
732 ///
733 /// * `shape` - The shape of the tensor.
734 /// * `fill_value` - The value with which to fill the tensor.
735 /// * `device` - The device to create the tensor on.
736 /// * `dtype` - The target data type.
737 ///
738 /// # Returns
739 ///
740 /// The tensor filled with given value
741 fn int_full(
742 shape: Shape,
743 fill_value: IntElem<B>,
744 device: &Device<B>,
745 dtype: IntDType,
746 ) -> IntTensor<B> {
747 Self::int_from_data(
748 TensorData::full_dtype(shape, fill_value, dtype.into()),
749 device,
750 )
751 }
752
753 /// Sums all elements in the tensor.
754 ///
755 /// # Arguments
756 ///
757 /// * `tensor` - The tensor to sum.
758 ///
759 /// # Returns
760 ///
761 /// The sum of all elements in the tensor.
762 fn int_sum(tensor: IntTensor<B>) -> IntTensor<B>;
763
764 /// Sums all elements in the tensor along a dimension.
765 ///
766 /// # Arguments
767 ///
768 /// * `tensor` - The tensor to sum.
769 /// * `dim` - The dimension to sum along.
770 ///
771 /// # Returns
772 ///
773 /// The sum of all elements in the tensor along the dimension.
774 fn int_sum_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
775
776 /// Computes the product of all elements in the tensor.
777 ///
778 /// # Arguments
779 ///
780 /// * `tensor` - The tensor to compute the product of.
781 ///
782 /// # Returns
783 ///
784 /// The product of all elements in the tensor.
785 fn int_prod(tensor: IntTensor<B>) -> IntTensor<B>;
786
787 /// Computes the product of all elements in the tensor along a dimension.
788 ///
789 /// # Arguments
790 ///
791 /// * `tensor` - The tensor to compute the product of.
792 /// * `dim` - The dimension to compute the product along.
793 ///
794 /// # Returns
795 ///
796 /// The product of all elements in the tensor along the dimension.
797 fn int_prod_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
798
799 /// Computes the mean of all elements in the tensor.
800 ///
801 /// # Arguments
802 ///
803 /// * `tensor` - The tensor to compute the mean of.
804 ///
805 /// # Returns
806 ///
807 /// The mean of all elements in the tensor.
808 fn int_mean(tensor: IntTensor<B>) -> IntTensor<B> {
809 let num_elems = tensor.shape().num_elements();
810 B::int_div_scalar(B::int_sum(tensor), (num_elems as i64).elem())
811 }
812
813 /// Computes the mean of all elements in the tensor along a dimension.
814 ///
815 /// # Arguments
816 ///
817 /// * `tensor` - The tensor to compute the mean of.
818 ///
819 /// # Returns
820 ///
821 /// The mean of all elements in the tensor along the dimension.
822 fn int_mean_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
823
824 /// Computes the cumulative sum of elements along a dimension.
825 ///
826 /// # Arguments
827 ///
828 /// * `tensor` - The tensor to compute the cumulative sum of.
829 /// * `dim` - The dimension along which to compute the cumulative sum.
830 ///
831 /// # Returns
832 ///
833 /// A tensor with the same shape where each element is the cumulative sum
834 /// of all elements up to and including that position along the dimension.
835 fn int_cumsum(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
836
837 /// Computes the cumulative product of elements along a dimension.
838 ///
839 /// # Arguments
840 ///
841 /// * `tensor` - The tensor to compute the cumulative product of.
842 /// * `dim` - The dimension along which to compute the cumulative product.
843 ///
844 /// # Returns
845 ///
846 /// A tensor with the same shape where each element is the cumulative product
847 /// of all elements up to and including that position along the dimension.
848 fn int_cumprod(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
849
850 /// Computes the cumulative minimum of elements along a dimension.
851 ///
852 /// # Arguments
853 ///
854 /// * `tensor` - The tensor to compute the cumulative minimum of.
855 /// * `dim` - The dimension along which to compute the cumulative minimum.
856 ///
857 /// # Returns
858 ///
859 /// A tensor with the same shape where each element is the minimum
860 /// of all elements up to and including that position along the dimension.
861 fn int_cummin(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
862
863 /// Computes the cumulative maximum of elements along a dimension.
864 ///
865 /// # Arguments
866 ///
867 /// * `tensor` - The tensor to compute the cumulative maximum of.
868 /// * `dim` - The dimension along which to compute the cumulative maximum.
869 ///
870 /// # Returns
871 ///
872 /// A tensor with the same shape where each element is the maximum
873 /// of all elements up to and including that position along the dimension.
874 fn int_cummax(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
875
876 /// Gets the indices of the maximum elements along a dimension.
877 ///
878 /// # Arguments
879 ///
880 /// * `tensor` - The tensor to get the maximum indices of.
881 /// * `dim` - The dimension to get the maximum indices along.
882 ///
883 /// # Returns
884 ///
885 /// The indices of the maximum elements along the dimension.
886 fn int_argmax(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
887
888 /// Gets the indices of the minimum elements along a dimension.
889 ///
890 /// # Arguments
891 ///
892 /// * `tensor` - The tensor to get the minimum indices of.
893 /// * `dim` - The dimension to get the minimum indices along.
894 ///
895 /// # Returns
896 ///
897 /// The indices of the minimum elements along the dimension.
898 fn int_argmin(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
899
900 /// Gets the maximum element in the tensor.
901 ///
902 /// # Arguments
903 ///
904 /// * `tensor` - The tensor to get the maximum element of.
905 ///
906 /// # Returns
907 ///
908 /// The maximum element in the tensor.
909 fn int_max(tensor: IntTensor<B>) -> IntTensor<B> {
910 let shape = tensor.shape();
911 let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()]));
912
913 B::int_max_dim(tensor, 0)
914 }
915
916 /// Gets the maximum element in the tensor along a dimension.
917 ///
918 /// # Arguments
919 ///
920 /// * `tensor` - The tensor to get the maximum element of.
921 /// * `dim` - The dimension to get the maximum element along.
922 ///
923 /// # Returns
924 ///
925 /// The maximum element in the tensor along the dimension.
926 fn int_max_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
927 let index = B::int_argmax(tensor.clone(), dim);
928 B::int_gather(dim, tensor, index)
929 }
930
931 /// Gets the maximum elements and corresponding indices along a dimension.
932 ///
933 /// # Arguments
934 ///
935 /// * `tensor` - The tensor to get the maximum elements and indices of.
936 /// * `dim` - The dimension to get the maximum elements and indices along.
937 ///
938 /// # Returns
939 ///
940 /// The maximum elements and corresponding indices along the dimension.
941 fn int_max_dim_with_indices(tensor: IntTensor<B>, dim: usize) -> (IntTensor<B>, IntTensor<B>) {
942 let index = B::int_argmax(tensor.clone(), dim);
943 let values = B::int_gather(dim, tensor, index.clone());
944
945 (values, index)
946 }
947
948 /// Gets the maximum absolute element in the tensor.
949 ///
950 /// # Arguments
951 ///
952 /// * `tensor` - The tensor to get the maximum element of.
953 ///
954 /// # Returns
955 ///
956 /// The maximum element in the tensor.
957 fn int_max_abs(tensor: IntTensor<B>) -> IntTensor<B> {
958 let shape = tensor.shape();
959 let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()]));
960
961 B::int_max_abs_dim(tensor, 0)
962 }
963
964 /// Gets the maximum absolute element in the tensor along a dimension.
965 ///
966 /// # Arguments
967 ///
968 /// * `tensor` - The tensor to get the maximum element of.
969 /// * `dim` - The dimension to get the maximum element along.
970 ///
971 /// # Returns
972 ///
973 /// The maximum element in the tensor along the dimension.
974 fn int_max_abs_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
975 B::int_max_dim(B::int_abs(tensor), dim)
976 }
977
978 /// Gets the minimum element in the tensor.
979 ///
980 /// # Arguments
981 ///
982 /// * `tensor` - The tensor to get the minimum element of.
983 ///
984 /// # Returns
985 ///
986 /// The minimum element in the tensor.
987 fn int_min(tensor: IntTensor<B>) -> IntTensor<B> {
988 let shape = tensor.shape();
989 let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()]));
990
991 B::int_min_dim(tensor, 0)
992 }
993
994 /// Gets the minimum elements in the tensor along a dimension.
995 ///
996 /// # Arguments
997 ///
998 /// * `tensor` - The tensor to get the minimum element of.
999 /// * `dim` - The dimension to get the minimum element along.
1000 ///
1001 /// # Returns
1002 ///
1003 /// The minimum element in the tensor along the dimension.
1004 fn int_min_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
1005 let index = B::int_argmin(tensor.clone(), dim);
1006 B::int_gather(dim, tensor, index)
1007 }
1008
1009 /// Gets the minimum elements and corresponding indices along a dimension.
1010 ///
1011 /// # Arguments
1012 ///
1013 /// * `tensor` - The tensor to get the minimum elements and indices of.
1014 /// * `dim` - The dimension to get the minimum elements and indices along.
1015 ///
1016 /// # Returns
1017 ///
1018 /// The minimum elements and corresponding indices along the dimension.
1019 fn int_min_dim_with_indices(tensor: IntTensor<B>, dim: usize) -> (IntTensor<B>, IntTensor<B>) {
1020 let indices = B::int_argmin(tensor.clone(), dim);
1021 let values = B::int_gather(dim, tensor, indices.clone());
1022
1023 (values, indices)
1024 }
1025
1026 /// Returns a new tensor with absolute values.
1027 ///
1028 /// # Arguments
1029 ///
1030 /// * `tensor` - The tensor to take absolute value of.
1031 ///
1032 /// # Returns
1033 ///
1034 /// A tensor with the same shape as `tensor` with absolute values.
1035 fn int_abs(tensor: IntTensor<B>) -> IntTensor<B>;
1036
1037 /// Transposes an int tensor.
1038 ///
1039 /// # Arguments
1040 ///
1041 /// * `tensor` - The tensor to transpose.
1042 ///
1043 /// # Returns
1044 ///
1045 /// The transposed tensor.
1046 fn int_transpose(tensor: IntTensor<B>) -> IntTensor<B> {
1047 let ndims = tensor.shape().num_dims();
1048 Self::int_swap_dims(tensor, ndims - 2, ndims - 1)
1049 }
1050
    /// Swaps two dimensions of an int tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to swap the dimensions of.
    /// * `dim1` - The first dimension to swap.
    /// * `dim2` - The second dimension to swap.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions swapped.
    ///
    /// NOTE(review): handling of out-of-range dimensions (panic vs. error) is left to the
    /// backend implementation — confirm.
    fn int_swap_dims(tensor: IntTensor<B>, dim1: usize, dim2: usize) -> IntTensor<B>;
1063
    /// Permutes the dimensions of an int tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to permute the dimensions of.
    /// * `axes` - The new order of the dimensions; presumably a permutation of
    ///   `0..num_dims` (one entry per dimension) — confirm against backend implementations.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions permuted.
    fn int_permute(tensor: IntTensor<B>, axes: &[usize]) -> IntTensor<B>;
1074
    /// Reverses the order of elements in an int tensor along the given axes.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to reverse.
    /// * `axes` - The axes to reverse.
    ///
    /// # Returns
    ///
    /// The tensor with the elements reversed along every axis in `axes`.
    fn int_flip(tensor: IntTensor<B>, axes: &[usize]) -> IntTensor<B>;
1084
    /// Creates a new int tensor with random values.
    ///
    /// # Arguments
    ///
    /// * `shape` - The shape of the tensor.
    /// * `distribution` - The distribution to sample from.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// The tensor with the given shape and random values.
    ///
    /// NOTE(review): no dtype parameter is taken here, so the element type is presumably the
    /// backend's default int type — confirm.
    fn int_random(shape: Shape, distribution: Distribution, device: &Device<B>) -> IntTensor<B>;
1096
1097 /// Creates a new tensor with values from the given range with the given step size.
1098 ///
1099 /// # Arguments
1100 ///
1101 /// * `range` - The range of values.
1102 /// * `step` - The step size.
1103 /// * `device` - The device to create the tensor on.
1104 ///
1105 /// # Returns
1106 ///
1107 /// The tensor with the given values.
1108 fn int_arange_step(range: Range<i64>, step: usize, device: &Device<B>) -> IntTensor<B> {
1109 let value = range
1110 .step_by(step)
1111 .map(|i| i.elem())
1112 .collect::<Vec<IntElem<B>>>();
1113 let shape = Shape::new([value.len()]);
1114 let data = TensorData::new(value, shape);
1115 B::int_from_data(data, device)
1116 }
1117
1118 /// Creates a new tensor with values from the given range.
1119 ///
1120 /// # Arguments
1121 ///
1122 /// * `range` - The range of values.
1123 /// * `device` - The device to create the tensor on.
1124 ///
1125 /// # Returns
1126 ///
1127 /// The tensor with the given values.
1128 ///
1129 /// # Remarks
1130 ///
1131 /// Uses `arange_step` with a step size of 1 under the hood.
1132 fn int_arange(range: Range<i64>, device: &Device<B>) -> IntTensor<B> {
1133 Self::int_arange_step(range, 1, device)
1134 }
1135
1136 /// Tests if any element in the int `tensor` evaluates to True.
1137 ///
1138 /// # Arguments
1139 ///
1140 /// * `tensor` - The tensor to test.
1141 ///
1142 /// # Returns
1143 ///
1144 /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
1145 fn int_any(tensor: IntTensor<B>) -> BoolTensor<B> {
1146 let bool_tensor = B::int_equal_elem(tensor, 0.elem());
1147 let bool_tensor = B::bool_not(bool_tensor);
1148 let sum = B::int_sum(B::bool_into_int(bool_tensor));
1149 B::int_greater_elem(sum, 0.elem())
1150 }
1151
1152 /// Tests if any element in the int `tensor` evaluates to True along a given dimension `dim`.
1153 ///
1154 /// # Arguments
1155 ///
1156 /// * `tensor` - The tensor to test.
1157 /// * `dim` - The axis along which to test.
1158 ///
1159 /// # Returns
1160 ///
1161 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1162 /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input
1163 /// evaluates to True, False otherwise.
1164 fn int_any_dim(tensor: IntTensor<B>, dim: usize) -> BoolTensor<B> {
1165 let bool_tensor = B::int_equal_elem(tensor, 0.elem());
1166 let bool_tensor = B::bool_not(bool_tensor);
1167 let sum = B::int_sum_dim(B::bool_into_int(bool_tensor), dim);
1168 B::int_greater_elem(sum, 0.elem())
1169 }
1170
1171 /// Tests if all elements in the int `tensor` evaluate to True.
1172 ///
1173 /// # Arguments
1174 ///
1175 /// * `tensor` - The tensor to test.
1176 ///
1177 /// # Returns
1178 ///
1179 /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1180 /// evaluate to True, False otherwise.
1181 fn int_all(tensor: IntTensor<B>) -> BoolTensor<B> {
1182 let num_elems = tensor.shape().num_elements();
1183 let bool_tensor = B::int_equal_elem(tensor, 0.elem());
1184 let bool_tensor = B::bool_not(bool_tensor);
1185 let sum = B::int_sum(B::bool_into_int(bool_tensor));
1186 B::int_equal_elem(sum, (num_elems as i32).elem())
1187 }
1188
1189 /// Tests if all elements in the int `tensor` evaluate to True along a given dimension `dim`.
1190 ///
1191 /// # Arguments
1192 ///
1193 /// * `tensor` - The tensor to test.
1194 /// * `dim` - The axis along which to test.
1195 ///
1196 /// # Returns
1197 ///
1198 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1199 /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1200 /// evaluates to True, False otherwise.
1201 fn int_all_dim(tensor: IntTensor<B>, dim: usize) -> BoolTensor<B> {
1202 let num_elems = tensor.shape().dims[dim];
1203 let bool_tensor = B::int_equal_elem(tensor, 0.elem());
1204 let bool_tensor = B::bool_not(bool_tensor);
1205 let sum = B::int_sum_dim(B::bool_into_int(bool_tensor), dim);
1206 B::int_equal_elem(sum, (num_elems as i32).elem())
1207 }
1208
1209 /// Returns the signs of the int `tensor`.
1210 ///
1211 /// # Arguments
1212 ///
1213 /// * `tensor` - The tensor to extract the signs from.
1214 ///
1215 /// # Returns
1216 ///
1217 /// A tensor with the same shape as `tensor` containing the signs of the elements of `tensor`.
1218 fn int_sign(tensor: IntTensor<B>) -> IntTensor<B> {
1219 let dtype = tensor.dtype();
1220 let zeros = B::int_zeros(tensor.shape(), &B::int_device(&tensor), dtype.into());
1221 let less_than_zero = B::int_lower_elem(tensor.clone(), 0.0f32.elem());
1222 let greater_than_zero = B::int_greater_elem(tensor, 0.0f32.elem());
1223
1224 let mut result = B::int_mask_fill(zeros, less_than_zero, (-1.0f32).elem());
1225 result = B::int_mask_fill(result, greater_than_zero, 1.0f32.elem());
1226 result
1227 }
1228
    /// Broadcasts the int `tensor` to the given `shape`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to broadcast.
    /// * `shape` - The target shape.
    ///
    /// # Returns
    ///
    /// The tensor broadcast to `shape`.
    ///
    /// NOTE(review): the shape-compatibility rules, and whether the result is a view or a copy,
    /// are not stated at this declaration — confirm against the backend implementations.
    fn int_expand(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B>;
1231
    /// Sorts the elements of the input `tensor` by value along a given dimension.
    ///
    /// This sort is unstable (i.e., may reorder equal elements).
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    /// * `dim` - The axis along which to sort.
    /// * `descending` - The sorting order: descending if true, ascending otherwise.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
    fn int_sort(tensor: IntTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
        // Default implementation: the crate-level generic `sort` helper, specialized to `Int`.
        sort::<B, Int>(tensor, dim, descending)
    }
1248
    /// Sorts the elements of the input `tensor` by value along a given dimension, also returning
    /// the indices of the sorted elements.
    ///
    /// This sort is unstable (i.e., may reorder equal elements).
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    /// * `dim` - The axis along which to sort.
    /// * `descending` - The sorting order: descending if true, ascending otherwise.
    ///
    /// # Returns
    ///
    /// A `(values, indices)` pair, both with the same shape as the input tensor: the elements
    /// sorted by value, and the indices mapping each sorted element back to its position in the
    /// original input tensor.
    fn int_sort_with_indices(
        tensor: IntTensor<B>,
        dim: usize,
        descending: bool,
    ) -> (IntTensor<B>, IntTensor<B>) {
        // Default implementation: the crate-level generic helper, specialized to `Int`.
        sort_with_indices::<B, Int>(tensor, dim, descending)
    }
1269
    /// Returns the indices that sort the elements of the input `tensor` by value
    /// along a given dimension.
    ///
    /// This sort is unstable (i.e., may reorder equal elements).
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    /// * `dim` - The axis along which to sort.
    /// * `descending` - The sorting order: descending if true, ascending otherwise.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as the input tensor, whose entries are the indices that map
    /// each sorted position back to the original input tensor.
    fn int_argsort(tensor: IntTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
        // Default implementation: the crate-level generic `argsort` helper, specialized to `Int`.
        argsort::<B, Int>(tensor, dim, descending)
    }
1287
    /// Bitwise AND operation for Int Tensors.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand operand.
    /// * `rhs` - The right-hand operand.
    ///
    /// # Returns
    ///
    /// A tensor holding `lhs & rhs`.
    /// NOTE(review): presumably element-wise; shape/broadcasting requirements are not stated
    /// here — confirm against the backend implementations.
    fn bitwise_and(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

    /// Bitwise AND operation for Int Tensors with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The tensor operand.
    /// * `rhs` - The scalar operand.
    ///
    /// # Returns
    ///
    /// A tensor holding `lhs & rhs`, presumably with `rhs` applied to every element of `lhs`.
    fn bitwise_and_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
1293
    /// Bitwise OR operation for Int Tensors.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand operand.
    /// * `rhs` - The right-hand operand.
    ///
    /// # Returns
    ///
    /// A tensor holding `lhs | rhs`.
    /// NOTE(review): presumably element-wise; shape/broadcasting requirements are not stated
    /// here — confirm against the backend implementations.
    fn bitwise_or(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

    /// Bitwise OR operation for Int Tensors with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The tensor operand.
    /// * `rhs` - The scalar operand.
    ///
    /// # Returns
    ///
    /// A tensor holding `lhs | rhs`, presumably with `rhs` applied to every element of `lhs`.
    fn bitwise_or_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
1299
    /// Bitwise XOR operation for Int Tensors.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand operand.
    /// * `rhs` - The right-hand operand.
    ///
    /// # Returns
    ///
    /// A tensor holding `lhs ^ rhs`.
    /// NOTE(review): presumably element-wise; shape/broadcasting requirements are not stated
    /// here — confirm against the backend implementations.
    fn bitwise_xor(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

    /// Bitwise XOR operation for Int Tensors with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The tensor operand.
    /// * `rhs` - The scalar operand.
    ///
    /// # Returns
    ///
    /// A tensor holding `lhs ^ rhs`, presumably with `rhs` applied to every element of `lhs`.
    fn bitwise_xor_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
1305
    /// Bitwise NOT operation for Int Tensors.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor whose bits are inverted.
    ///
    /// # Returns
    ///
    /// A tensor holding `!tensor`, presumably computed for every element.
    fn bitwise_not(tensor: IntTensor<B>) -> IntTensor<B>;
1308
    /// Bitwise left shift operation for Int Tensors.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The values to shift.
    /// * `rhs` - The shift amounts.
    ///
    /// # Returns
    ///
    /// A tensor holding `lhs << rhs`.
    /// NOTE(review): behavior for negative or out-of-range shift amounts, and the
    /// shape/broadcasting requirements, are not stated here — confirm per backend.
    fn bitwise_left_shift(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

    /// Bitwise left shift operation for Int Tensors with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The values to shift.
    /// * `rhs` - The scalar shift amount.
    ///
    /// # Returns
    ///
    /// A tensor holding `lhs << rhs`, presumably with `rhs` applied to every element of `lhs`.
    fn bitwise_left_shift_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
1314
    /// Bitwise right shift operation for Int Tensors.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The values to shift.
    /// * `rhs` - The shift amounts.
    ///
    /// # Returns
    ///
    /// A tensor holding `lhs >> rhs`.
    /// NOTE(review): arithmetic vs. logical shifting for negative values, out-of-range shift
    /// amounts, and shape/broadcasting requirements are not stated here — confirm per backend.
    fn bitwise_right_shift(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

    /// Bitwise right shift operation for Int Tensors with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The values to shift.
    /// * `rhs` - The scalar shift amount.
    ///
    /// # Returns
    ///
    /// A tensor holding `lhs >> rhs`, presumably with `rhs` applied to every element of `lhs`.
    fn bitwise_right_shift_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
1320
    /// Converts a tensor to another integer data type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to convert.
    /// * `dtype` - The target data type.
    ///
    /// # Returns
    ///
    /// A tensor with the same values as `tensor` but in the target integer data type.
    ///
    /// NOTE(review): what happens to values that do not fit in a narrower `dtype`
    /// (wrap, saturate, or error) is not stated here — confirm per backend.
    fn int_cast(tensor: IntTensor<B>, dtype: IntDType) -> IntTensor<B>;
1332
    /// Unfolds windows along a dimension.
    ///
    /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`,
    /// where consecutive windows start `step` indices apart (the first starting at index 0).
    ///
    /// The number of windows is `(shape[dim] - size) / step + 1` (integer division) when
    /// `shape[dim] >= size`, and `0` otherwise. For example, `shape[dim] = 5`, `size = 2`,
    /// `step = 1` gives 4 windows.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor to unfold; of shape `[pre=..., dim shape, post=...]`.
    /// * `dim` - The selected dim.
    /// * `size` - The size of each unfolded window.
    /// * `step` - The step between the starts of consecutive windows.
    ///
    /// # Returns
    ///
    /// A tensor view with shape `[pre=..., windows, size, post=...]`.
    fn int_unfold(tensor: IntTensor<B>, dim: usize, size: usize, step: usize) -> IntTensor<B>;
1351}