burn_backend/backend/ops/int_tensor.rs
1use super::cat::cat_with_slice_assign;
2use super::repeat_dim::repeat_with_slice_assign;
3use super::sort::{argsort, sort, sort_with_indices};
4use crate::ExecutionError;
5use crate::tensor::{BoolTensor, Device, FloatTensor, Int, IntElem, IntTensor};
6use crate::{Backend, Distribution, TensorData, TensorMetadata, element::ElementConversion};
7use alloc::vec::Vec;
8use burn_std::{IntDType, Shape, Slice};
9use core::ops::Range;
10
11/// Int Tensor API for basic and numeric operations, see
12#[cfg_attr(doc, doc = crate::doc_tensor!())]
13#[cfg_attr(not(doc), doc = "`Tensor`")]
14/// for documentation on each function.
15pub trait IntTensorOps<B: Backend> {
/// Creates a new int tensor.
///
/// NOTE(review): the element values of the returned tensor are presumably
/// unspecified (the name suggests an "empty" allocation) — confirm per backend
/// before relying on them.
///
/// # Arguments
///
/// * `shape` - The shape of the tensor.
/// * `device` - The device to create the tensor on.
/// * `dtype` - The target integer data type.
///
/// # Returns
///
/// The integer tensor with the given shape.
fn int_empty(shape: Shape, device: &Device<B>, dtype: IntDType) -> IntTensor<B>;

/// Converts the tensor to a data structure.
///
/// # Arguments
///
/// * `tensor` - The tensor (consumed).
///
/// # Returns
///
/// A `Send` future that resolves to the tensor's data, or to an
/// [`ExecutionError`] if reading the tensor back fails.
fn int_into_data(
    tensor: IntTensor<B>,
) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;

/// Creates a tensor from the data structure.
///
/// # Arguments
///
/// * `data` - The data structure.
/// * `device` - The device to create the tensor on.
///
/// # Returns
///
/// The tensor with the data.
fn int_from_data(data: TensorData, device: &Device<B>) -> IntTensor<B>;

/// Gets the device of the tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor (borrowed, not consumed).
///
/// # Returns
///
/// The device of the tensor.
fn int_device(tensor: &IntTensor<B>) -> Device<B>;

/// Moves the tensor to the given device.
///
/// # Arguments
///
/// * `tensor` - The tensor to move.
/// * `device` - The destination device.
///
/// # Returns
///
/// The tensor located on `device`.
fn int_to_device(tensor: IntTensor<B>, device: &Device<B>) -> IntTensor<B>;

/// Reshapes the tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `shape` - The new shape.
///
/// # Returns
///
/// The tensor with the new shape.
fn int_reshape(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B>;

/// Gets the elements selected by the given slices.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `slices` - The slices specifying ranges and steps for each dimension.
///
/// # Returns
///
/// The elements at the given indices.
///
/// # Note
///
/// Empty slices (where start >= end) are handled at the high-level tensor API and will not
/// be passed to this method. Backend implementations do not need to handle empty slices.
fn int_slice(tensor: IntTensor<B>, slices: &[Slice]) -> IntTensor<B>;

/// Sets the values in the tensor for the given slices.
///
/// # Arguments
///
/// * `tensor` - The tensor to write into.
/// * `slices` - The slices (one per dimension) selecting the region to overwrite.
/// * `value` - The tensor whose elements are written into the selected region.
///
/// # Returns
///
/// The tensor with the values set for the given slices.
///
/// # Note
///
/// Empty slice assignments (where any slice range produces 0 elements) are handled at the
/// high-level tensor API and will not be passed to this method. Backend implementations do
/// not need to handle empty slice assignments.
fn int_slice_assign(
    tensor: IntTensor<B>,
    slices: &[Slice],
    value: IntTensor<B>,
) -> IntTensor<B>;
118
/// Converts an int tensor to a float tensor.
///
/// # Arguments
///
/// * `tensor` - The int tensor to convert.
///
/// # Returns
///
/// A float tensor holding the values of `tensor` converted to floating point.
fn int_into_float(tensor: IntTensor<B>) -> FloatTensor<B>;
129
/// Fills the tensor with values from the source tensor where the mask is true.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `mask` - The boolean mask selecting which positions take values from `source`.
/// * `source` - The source tensor.
///
/// # Returns
///
/// The tensor with the values filled.
fn int_mask_where(
    tensor: IntTensor<B>,
    mask: BoolTensor<B>,
    source: IntTensor<B>,
) -> IntTensor<B>;

/// Fills the tensor with the given value where the mask is true.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `mask` - The boolean mask selecting which positions are overwritten.
/// * `value` - The scalar value to write.
///
/// # Returns
///
/// The tensor with the values filled.
fn int_mask_fill(tensor: IntTensor<B>, mask: BoolTensor<B>, value: IntElem<B>) -> IntTensor<B>;

/// Gathers elements from the tensor at the given indices.
///
/// # Arguments
///
/// * `dim` - The dimension to gather from.
/// * `tensor` - The tensor.
/// * `indices` - The indices.
///
/// # Returns
///
/// The tensor of gathered elements.
fn int_gather(dim: usize, tensor: IntTensor<B>, indices: IntTensor<B>) -> IntTensor<B>;

/// Scatters the given values into the tensor at the given indices using sum reduction.
///
/// # Arguments
///
/// * `dim` - The dimension to scatter to.
/// * `tensor` - The tensor.
/// * `indices` - The indices.
/// * `value` - The values to scatter.
///
/// # Returns
///
/// The tensor with the values scattered.
fn int_scatter_add(
    dim: usize,
    tensor: IntTensor<B>,
    indices: IntTensor<B>,
    value: IntTensor<B>,
) -> IntTensor<B>;

/// Selects tensor elements along the given dimension corresponding to the given indices.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `dim` - The dimension to select from.
/// * `indices` - The indices.
///
/// # Returns
///
/// The tensor with the selected elements.
fn int_select(tensor: IntTensor<B>, dim: usize, indices: IntTensor<B>) -> IntTensor<B>;

/// Assigns `value` to the elements selected along `dim` by `indices`, using sum
/// reduction.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `dim` - The dimension to select from.
/// * `indices` - The indices.
/// * `value` - The values to add at the selected positions.
///
/// # Returns
///
/// The tensor with the selected elements assigned to the given value.
fn int_select_add(
    tensor: IntTensor<B>,
    dim: usize,
    indices: IntTensor<B>,
    value: IntTensor<B>,
) -> IntTensor<B>;
221
222 /// Repeats the tensor along the given dimension the given number of times.
223 ///
224 /// # Arguments
225 ///
226 /// * `tensor` - The tensor.
227 /// * `dim` - The dimension to repeat.
228 /// * `times` - The number of times to repeat.
229 ///
230 /// # Returns
231 ///
232 /// The tensor with the given dimension repeated the given number of times.
233 fn int_repeat_dim(tensor: IntTensor<B>, dim: usize, times: usize) -> IntTensor<B> {
234 repeat_with_slice_assign::<B, Int>(tensor, dim, times)
235 }
236
237 /// Concatenates the given tensors along the given dimension.
238 ///
239 /// # Arguments
240 ///
241 /// * `tensors` - The tensors.
242 /// * `dim` - The dimension to concatenate along.
243 ///
244 /// # Returns
245 ///
246 /// The concatenated tensor.
247 ///
248 /// # Note
249 ///
250 /// Empty tensors (where the concatenation dimension has size 0) are filtered out at the
251 /// high-level tensor API and will not be passed to this method. Backend implementations do
252 /// not need to handle empty tensors.
253 fn int_cat(tensors: Vec<IntTensor<B>>, dim: usize) -> IntTensor<B> {
254 cat_with_slice_assign::<B, Int>(tensors, dim)
255 }
256
/// Element-wise equality comparison.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side tensor.
///
/// # Returns
///
/// The boolean tensor that is true wherever `lhs == rhs`.
fn int_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;
268
269 /// Element-wise non-equality comparison.
270 ///
271 /// # Arguments
272 ///
273 /// * `lhs` - The left-hand side tensor.
274 /// * `rhs` - The right-hand side tensor.
275 ///
276 /// # Returns
277 ///
278 /// The boolean tensor with the result of the comparison.
279 fn int_not_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {
280 let equal_tensor = B::int_equal(lhs, rhs);
281 B::bool_not(equal_tensor)
282 }
283
/// Element-wise equality comparison with a scalar.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side scalar, compared against every element.
///
/// # Returns
///
/// The boolean tensor that is true wherever an element of `lhs` equals `rhs`.
fn int_equal_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;
295
296 /// Element-wise non-equality comparison with a scalar.
297 ///
298 /// # Arguments
299 ///
300 /// * `lhs` - The left-hand side tensor.
301 /// * `rhs` - The right-hand side scalar.
302 ///
303 /// # Returns
304 ///
305 /// The boolean tensor with the result of the comparison.
306 fn int_not_equal_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B> {
307 let equal_tensor = B::int_equal_elem(lhs, rhs);
308 B::bool_not(equal_tensor)
309 }
310
/// Element-wise greater-than comparison.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side tensor.
///
/// # Returns
///
/// The boolean tensor that is true wherever `lhs > rhs`.
fn int_greater(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;

/// Element-wise greater-than comparison with a scalar.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side scalar.
///
/// # Returns
///
/// The boolean tensor that is true wherever an element of `lhs` is greater than `rhs`.
fn int_greater_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;

/// Element-wise greater-than-or-equal comparison.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side tensor.
///
/// # Returns
///
/// The boolean tensor that is true wherever `lhs >= rhs`.
fn int_greater_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;

/// Element-wise greater-than-or-equal comparison with a scalar.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side scalar.
///
/// # Returns
///
/// The boolean tensor that is true wherever an element of `lhs` is `>= rhs`.
fn int_greater_equal_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;

/// Element-wise less-than comparison.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side tensor.
///
/// # Returns
///
/// The boolean tensor that is true wherever `lhs < rhs`.
fn int_lower(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;

/// Element-wise less-than comparison with a scalar.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side scalar.
///
/// # Returns
///
/// The boolean tensor that is true wherever an element of `lhs` is less than `rhs`.
fn int_lower_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;

/// Element-wise less-than-or-equal comparison.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side tensor.
///
/// # Returns
///
/// The boolean tensor that is true wherever `lhs <= rhs`.
fn int_lower_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;

/// Element-wise less-than-or-equal comparison with a scalar.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side scalar.
///
/// # Returns
///
/// The boolean tensor that is true wherever an element of `lhs` is `<= rhs`.
fn int_lower_equal_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;
406
407 // ==== NUMERIC ==== //
408
/// Element-wise addition.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side tensor.
///
/// # Returns
///
/// The result of the addition.
fn int_add(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

/// Element-wise addition with a scalar.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side scalar, added to every element.
///
/// # Returns
///
/// The result of the addition.
fn int_add_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
432
433 /// Element-wise power with a IntTensor.
434 ///
435 /// # Arguments
436 ///
437 /// * `lhs` - The left-hand side IntTensor.
438 /// * `rhs` - The right-hand side IntTensor.
439 ///
440 /// # Returns
441 ///
442 /// The elements of `lhs` raised to the power of the elements of `rhs`.
443 fn int_powi(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
444 B::float_into_int(B::float_powi(B::int_into_float(lhs), rhs))
445 }
446
447 /// Element-wise power with a floatTensor.
448 ///
449 /// # Arguments
450 ///
451 /// * `lhs` - The left-hand side tensor.
452 /// * `rhs` - The right-hand side floatTensor.
453 ///
454 /// # Returns
455 ///
456 /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
457 fn int_powf(lhs: IntTensor<B>, rhs: FloatTensor<B>) -> IntTensor<B> {
458 B::float_into_int(B::float_powf(B::int_into_float(lhs), rhs))
459 }
460
461 /// Element-wise power with a scalar.
462 ///
463 /// # Backend Implementors Note
464 ///
465 /// A number of common exponent cases can be implemented with operations
466 /// which are much cheaper than generic exponentiation.
467 ///
468 /// This (`Backend` impl overridable) operation handles generic optimizations
469 /// for several common integer exponent cases; and then dispatches to
470 /// the (`Backend` impl overridable) [`Self::int_powi_scalar_impl`]
471 /// operation to handle the generic case.
472 ///
473 /// # Arguments
474 ///
475 /// * `lhs` - The left-hand side tensor.
476 /// * `rhs` - The right-hand side scalar.
477 ///
478 /// # Returns
479 ///
480 /// The elements of `lhs` raised to the value of `rhs`.
481 fn int_powi_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B> {
482 let exp = rhs.elem::<i32>();
483 match exp {
484 0 => Self::int_ones(lhs.shape(), &B::int_device(&lhs), lhs.dtype().into()),
485 1 => lhs,
486 2 => Self::int_mul(lhs.clone(), lhs),
487 _ => Self::int_powi_scalar_impl(lhs, rhs),
488 }
489 }
490
491 /// Element-wise power with a scalar.
492 ///
493 /// # Backend Implementors Note
494 ///
495 /// This is the generic implementation of integer exponentiation
496 /// called by [`Self::int_powi_scalar`] in the fallback case.
497 ///
498 /// By default, this performs a relatively expensive conversion to float,
499 /// exponentiation in float, and conversion back to int.
500 /// This reduces the minimal operation set for `Backend`s,
501 /// at the cost of performance.
502 ///
503 /// This is a good target for specialized optimizations in `Backend` implementations.
504 ///
505 /// As a general rule, this should not be called directly.
506 ///
507 /// # Arguments
508 ///
509 /// * `lhs` - The left-hand side tensor.
510 /// * `rhs` - The right-hand side scalar.
511 ///
512 /// # Returns
513 ///
514 /// The elements of `lhs` raised to the value of `rhs`.
515 fn int_powi_scalar_impl(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B> {
516 B::float_into_int(B::float_powi_scalar_impl(B::int_into_float(lhs), rhs))
517 }
518
519 /// Element-wise power with a floatTensor.
520 ///
521 /// Handles a number of special cases, then calls [`Self::int_powf_scalar_impl`].
522 ///
523 /// # Arguments
524 ///
525 /// * `lhs` - The left-hand side tensor.
526 /// * `rhs` - The right-hand side scalar.
527 ///
528 /// # Returns
529 ///
530 /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
531 fn int_powf_scalar(lhs: IntTensor<B>, rhs: f32) -> IntTensor<B> {
532 if num_traits::Float::floor(rhs) == rhs {
533 let exp = B::IntElem::from_elem(rhs as i32);
534 Self::int_powi_scalar(lhs, exp)
535 } else {
536 Self::int_powf_scalar_impl(lhs, rhs)
537 }
538 }
539
540 /// Element-wise power with a floatTensor.
541 ///
542 /// Fallback handler for [`Self::int_powf_scalar`].
543 ///
544 /// # Arguments
545 ///
546 /// * `lhs` - The left-hand side tensor.
547 /// * `rhs` - The right-hand side scalar.
548 ///
549 /// # Returns
550 ///
551 /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
552 fn int_powf_scalar_impl(lhs: IntTensor<B>, rhs: f32) -> IntTensor<B> {
553 B::float_into_int(B::float_powf_scalar_impl(B::int_into_float(lhs), rhs))
554 }
555
556 /// Clamps a tensor under a minimum value.
557 ///
558 /// # Arguments
559 ///
560 /// * `tensor` - The tensor to clamp.
561 /// * `min` - The minimum value.
562 ///
563 /// # Returns
564 ///
565 /// The clamped tensor.
566 fn int_clamp_min(tensor: IntTensor<B>, min: IntElem<B>) -> IntTensor<B> {
567 let mask = Self::int_lower_elem(tensor.clone(), min);
568 Self::int_mask_fill(tensor, mask, min)
569 }
570
571 /// Clamps a tensor over a maximum value.
572 ///
573 /// # Arguments
574 ///
575 /// * `tensor` - The tensor to clamp.
576 /// * `max` - The maximum value.
577 ///
578 /// # Returns
579 ///
580 /// The clamped tensor.
581 fn int_clamp_max(tensor: IntTensor<B>, max: IntElem<B>) -> IntTensor<B> {
582 let mask = Self::int_greater_elem(tensor.clone(), max);
583 Self::int_mask_fill(tensor, mask, max)
584 }
585
586 /// Clamps a tensor between a minimum and maximum value.
587 ///
588 /// # Arguments
589 ///
590 /// * `tensor` - The tensor to clamp.
591 /// * `min` - The minimum value.
592 /// * `max` - The maximum value.
593 ///
594 /// # Returns
595 ///
596 /// The clamped tensor.
597 fn int_clamp(tensor: IntTensor<B>, min: IntElem<B>, max: IntElem<B>) -> IntTensor<B> {
598 Self::int_clamp_min(Self::int_clamp_max(tensor, max), min)
599 }
600
/// Element-wise subtraction.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side tensor.
///
/// # Returns
///
/// The result of `lhs - rhs`.
fn int_sub(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

/// Element-wise subtraction with a scalar.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side scalar, subtracted from every element.
///
/// # Returns
///
/// The result of the subtraction.
fn int_sub_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;

/// Element-wise multiplication.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side tensor.
///
/// # Returns
///
/// The result of the multiplication.
fn int_mul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

/// Element-wise multiplication with a scalar.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side scalar, multiplied into every element.
///
/// # Returns
///
/// The result of the multiplication.
fn int_mul_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;

/// Element-wise division.
///
/// NOTE(review): rounding behavior for non-exact integer division (truncation vs.
/// flooring) is not specified here — confirm against the backend implementations.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side tensor.
///
/// # Returns
///
/// The result of the division.
fn int_div(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

/// Element-wise division with a scalar.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side scalar divisor.
///
/// # Returns
///
/// The result of the division.
fn int_div_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
672
/// Element-wise modulus.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side tensor (the divisor).
///
/// # Returns
///
/// The element-wise remainder of dividing `lhs` by `rhs`.
fn int_remainder(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

/// Element-wise modulus with a scalar.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side scalar (the divisor).
///
/// # Returns
///
/// The result of applying the modulus of the scalar to the tensor.
fn int_remainder_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
694
/// Multiplies two tensors together using matrix multiplication.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side tensor.
///
/// # Returns
///
/// The matrix product of `lhs` and `rhs`.
fn int_matmul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
706
707 /// Element-wise negation.
708 ///
709 /// # Arguments
710 ///
711 /// * `tensor` - The tensor to negate.
712 ///
713 /// # Returns
714 ///
715 /// The negated tensor.
716 fn int_neg(tensor: IntTensor<B>) -> IntTensor<B> {
717 Self::int_mul_scalar(tensor, (-1.0).elem::<IntElem<B>>())
718 }
719
720 /// Creates a tensor of zeros.
721 ///
722 /// # Arguments
723 ///
724 /// * `shape` - The shape of the tensor.
725 /// * `device` - The device to create the tensor on.
726 /// * `dtype` - The target data type.
727 ///
728 /// # Returns
729 ///
730 /// The tensor of zeros.
731 fn int_zeros(shape: Shape, device: &Device<B>, dtype: IntDType) -> IntTensor<B> {
732 Self::int_from_data(TensorData::full_dtype(shape, 0, dtype.into()), device)
733 }
734
735 /// Creates a tensor of ones.
736 ///
737 /// # Arguments
738 ///
739 /// * `shape` - The shape of the tensor.
740 /// * `device` - The device to create the tensor on.
741 /// * `dtype` - The target data type.
742 ///
743 /// # Returns
744 ///
745 /// The tensor of ones.
746 fn int_ones(shape: Shape, device: &Device<B>, dtype: IntDType) -> IntTensor<B> {
747 Self::int_from_data(TensorData::full_dtype(shape, 1, dtype.into()), device)
748 }
749
750 /// Creates a tensor filled with given value.
751 ///
752 /// # Arguments
753 ///
754 /// * `shape` - The shape of the tensor.
755 /// * `fill_value` - The value with which to fill the tensor.
756 /// * `device` - The device to create the tensor on.
757 /// * `dtype` - The target data type.
758 ///
759 /// # Returns
760 ///
761 /// The tensor filled with given value
762 fn int_full(
763 shape: Shape,
764 fill_value: IntElem<B>,
765 device: &Device<B>,
766 dtype: IntDType,
767 ) -> IntTensor<B> {
768 Self::int_from_data(
769 TensorData::full_dtype(shape, fill_value, dtype.into()),
770 device,
771 )
772 }
773
/// Sums all elements in the tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to sum.
///
/// # Returns
///
/// A tensor holding the sum of all elements in `tensor`.
fn int_sum(tensor: IntTensor<B>) -> IntTensor<B>;

/// Sums the elements of the tensor along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to sum.
/// * `dim` - The dimension to sum along.
///
/// # Returns
///
/// The sum of the elements along `dim`.
fn int_sum_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;

/// Computes the product of all elements in the tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to compute the product of.
///
/// # Returns
///
/// A tensor holding the product of all elements in `tensor`.
fn int_prod(tensor: IntTensor<B>) -> IntTensor<B>;

/// Computes the product of the elements of the tensor along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to compute the product of.
/// * `dim` - The dimension to compute the product along.
///
/// # Returns
///
/// The product of the elements along `dim`.
fn int_prod_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
819
820 /// Computes the mean of all elements in the tensor.
821 ///
822 /// # Arguments
823 ///
824 /// * `tensor` - The tensor to compute the mean of.
825 ///
826 /// # Returns
827 ///
828 /// The mean of all elements in the tensor.
829 fn int_mean(tensor: IntTensor<B>) -> IntTensor<B> {
830 let num_elems = tensor.shape().num_elements();
831 B::int_div_scalar(B::int_sum(tensor), (num_elems as i64).elem())
832 }
833
/// Computes the mean of the elements of the tensor along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to compute the mean of.
/// * `dim` - The dimension to compute the mean along.
///
/// # Returns
///
/// The mean of the elements along `dim`.
fn int_mean_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;

/// Computes the cumulative sum of elements along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to compute the cumulative sum of.
/// * `dim` - The dimension along which to compute the cumulative sum.
///
/// # Returns
///
/// A tensor with the same shape where each element is the cumulative sum
/// of all elements up to and including that position along the dimension.
fn int_cumsum(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;

/// Computes the cumulative product of elements along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to compute the cumulative product of.
/// * `dim` - The dimension along which to compute the cumulative product.
///
/// # Returns
///
/// A tensor with the same shape where each element is the cumulative product
/// of all elements up to and including that position along the dimension.
fn int_cumprod(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;

/// Computes the cumulative minimum of elements along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to compute the cumulative minimum of.
/// * `dim` - The dimension along which to compute the cumulative minimum.
///
/// # Returns
///
/// A tensor with the same shape where each element is the minimum
/// of all elements up to and including that position along the dimension.
fn int_cummin(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;

/// Computes the cumulative maximum of elements along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to compute the cumulative maximum of.
/// * `dim` - The dimension along which to compute the cumulative maximum.
///
/// # Returns
///
/// A tensor with the same shape where each element is the maximum
/// of all elements up to and including that position along the dimension.
fn int_cummax(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
896
/// Gets the indices of the maximum elements along a dimension.
///
/// These indices are what the default `int_max_dim` feeds into `int_gather`.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the maximum indices of.
/// * `dim` - The dimension to get the maximum indices along.
///
/// # Returns
///
/// The indices of the maximum elements along the dimension.
fn int_argmax(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;

/// Gets the indices of the minimum elements along a dimension.
///
/// These indices are what the default `int_min_dim` feeds into `int_gather`.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the minimum indices of.
/// * `dim` - The dimension to get the minimum indices along.
///
/// # Returns
///
/// The indices of the minimum elements along the dimension.
fn int_argmin(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
920
921 /// Gets the maximum element in the tensor.
922 ///
923 /// # Arguments
924 ///
925 /// * `tensor` - The tensor to get the maximum element of.
926 ///
927 /// # Returns
928 ///
929 /// The maximum element in the tensor.
930 fn int_max(tensor: IntTensor<B>) -> IntTensor<B> {
931 let shape = tensor.shape();
932 let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()]));
933
934 B::int_max_dim(tensor, 0)
935 }
936
937 /// Gets the maximum element in the tensor along a dimension.
938 ///
939 /// # Arguments
940 ///
941 /// * `tensor` - The tensor to get the maximum element of.
942 /// * `dim` - The dimension to get the maximum element along.
943 ///
944 /// # Returns
945 ///
946 /// The maximum element in the tensor along the dimension.
947 fn int_max_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
948 let index = B::int_argmax(tensor.clone(), dim);
949 B::int_gather(dim, tensor, index)
950 }
951
952 /// Gets the maximum elements and corresponding indices along a dimension.
953 ///
954 /// # Arguments
955 ///
956 /// * `tensor` - The tensor to get the maximum elements and indices of.
957 /// * `dim` - The dimension to get the maximum elements and indices along.
958 ///
959 /// # Returns
960 ///
961 /// The maximum elements and corresponding indices along the dimension.
962 fn int_max_dim_with_indices(tensor: IntTensor<B>, dim: usize) -> (IntTensor<B>, IntTensor<B>) {
963 let index = B::int_argmax(tensor.clone(), dim);
964 let values = B::int_gather(dim, tensor, index.clone());
965
966 (values, index)
967 }
968
969 /// Gets the maximum absolute element in the tensor.
970 ///
971 /// # Arguments
972 ///
973 /// * `tensor` - The tensor to get the maximum element of.
974 ///
975 /// # Returns
976 ///
977 /// The maximum element in the tensor.
978 fn int_max_abs(tensor: IntTensor<B>) -> IntTensor<B> {
979 let shape = tensor.shape();
980 let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()]));
981
982 B::int_max_abs_dim(tensor, 0)
983 }
984
985 /// Gets the maximum absolute element in the tensor along a dimension.
986 ///
987 /// # Arguments
988 ///
989 /// * `tensor` - The tensor to get the maximum element of.
990 /// * `dim` - The dimension to get the maximum element along.
991 ///
992 /// # Returns
993 ///
994 /// The maximum element in the tensor along the dimension.
995 fn int_max_abs_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
996 B::int_max_dim(B::int_abs(tensor), dim)
997 }
998
999 /// Gets the minimum element in the tensor.
1000 ///
1001 /// # Arguments
1002 ///
1003 /// * `tensor` - The tensor to get the minimum element of.
1004 ///
1005 /// # Returns
1006 ///
1007 /// The minimum element in the tensor.
1008 fn int_min(tensor: IntTensor<B>) -> IntTensor<B> {
1009 let shape = tensor.shape();
1010 let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()]));
1011
1012 B::int_min_dim(tensor, 0)
1013 }
1014
1015 /// Gets the minimum elements in the tensor along a dimension.
1016 ///
1017 /// # Arguments
1018 ///
1019 /// * `tensor` - The tensor to get the minimum element of.
1020 /// * `dim` - The dimension to get the minimum element along.
1021 ///
1022 /// # Returns
1023 ///
1024 /// The minimum element in the tensor along the dimension.
1025 fn int_min_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
1026 let index = B::int_argmin(tensor.clone(), dim);
1027 B::int_gather(dim, tensor, index)
1028 }
1029
1030 /// Gets the minimum elements and corresponding indices along a dimension.
1031 ///
1032 /// # Arguments
1033 ///
1034 /// * `tensor` - The tensor to get the minimum elements and indices of.
1035 /// * `dim` - The dimension to get the minimum elements and indices along.
1036 ///
1037 /// # Returns
1038 ///
1039 /// The minimum elements and corresponding indices along the dimension.
1040 fn int_min_dim_with_indices(tensor: IntTensor<B>, dim: usize) -> (IntTensor<B>, IntTensor<B>) {
1041 let indices = B::int_argmin(tensor.clone(), dim);
1042 let values = B::int_gather(dim, tensor, indices.clone());
1043
1044 (values, indices)
1045 }
1046
    /// Returns a new tensor holding the absolute value of each element of `tensor`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the absolute value of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` containing absolute values.
    fn int_abs(tensor: IntTensor<B>) -> IntTensor<B>;
1057
1058 /// Transposes an int tensor.
1059 ///
1060 /// # Arguments
1061 ///
1062 /// * `tensor` - The tensor to transpose.
1063 ///
1064 /// # Returns
1065 ///
1066 /// The transposed tensor.
1067 fn int_transpose(tensor: IntTensor<B>) -> IntTensor<B> {
1068 let ndims = tensor.shape().num_dims();
1069 Self::int_swap_dims(tensor, ndims - 2, ndims - 1)
1070 }
1071
    /// Swaps two dimensions of an int tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to swap the dimensions of.
    /// * `dim1` - The first dimension to swap.
    /// * `dim2` - The second dimension to swap.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions swapped.
    fn int_swap_dims(tensor: IntTensor<B>, dim1: usize, dim2: usize) -> IntTensor<B>;

    /// Permutes the dimensions of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to permute the dimensions of.
    /// * `axes` - The new order of the dimensions.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions permuted.
    fn int_permute(tensor: IntTensor<B>, axes: &[usize]) -> IntTensor<B>;

    /// Reverses the order of elements in a tensor along the given axes.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to reverse.
    /// * `axes` - The axes to reverse.
    ///
    /// # Returns
    ///
    /// The tensor with the elements reversed along `axes`.
    fn int_flip(tensor: IntTensor<B>, axes: &[usize]) -> IntTensor<B>;
1105
    /// Creates a new int tensor with random values.
    ///
    /// # Arguments
    ///
    /// * `shape` - The shape of the tensor.
    /// * `distribution` - The distribution to sample from.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// The tensor with the given shape and random values.
    fn int_random(shape: Shape, distribution: Distribution, device: &Device<B>) -> IntTensor<B>;
1117
1118 /// Creates a new tensor with values from the given range with the given step size.
1119 ///
1120 /// # Arguments
1121 ///
1122 /// * `range` - The range of values.
1123 /// * `step` - The step size.
1124 /// * `device` - The device to create the tensor on.
1125 ///
1126 /// # Returns
1127 ///
1128 /// The tensor with the given values.
1129 fn int_arange_step(range: Range<i64>, step: usize, device: &Device<B>) -> IntTensor<B> {
1130 let value = range
1131 .step_by(step)
1132 .map(|i| i.elem())
1133 .collect::<Vec<IntElem<B>>>();
1134 let shape = Shape::new([value.len()]);
1135 let data = TensorData::new(value, shape);
1136 B::int_from_data(data, device)
1137 }
1138
1139 /// Creates a new tensor with values from the given range.
1140 ///
1141 /// # Arguments
1142 ///
1143 /// * `range` - The range of values.
1144 /// * `device` - The device to create the tensor on.
1145 ///
1146 /// # Returns
1147 ///
1148 /// The tensor with the given values.
1149 ///
1150 /// # Remarks
1151 ///
1152 /// Uses `arange_step` with a step size of 1 under the hood.
1153 fn int_arange(range: Range<i64>, device: &Device<B>) -> IntTensor<B> {
1154 Self::int_arange_step(range, 1, device)
1155 }
1156
1157 /// Tests if any element in the int `tensor` evaluates to True.
1158 ///
1159 /// # Arguments
1160 ///
1161 /// * `tensor` - The tensor to test.
1162 ///
1163 /// # Returns
1164 ///
1165 /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
1166 fn int_any(tensor: IntTensor<B>) -> BoolTensor<B> {
1167 let bool_tensor = B::int_equal_elem(tensor, 0.elem());
1168 let bool_tensor = B::bool_not(bool_tensor);
1169 let sum = B::int_sum(B::bool_into_int(bool_tensor));
1170 B::int_greater_elem(sum, 0.elem())
1171 }
1172
1173 /// Tests if any element in the int `tensor` evaluates to True along a given dimension `dim`.
1174 ///
1175 /// # Arguments
1176 ///
1177 /// * `tensor` - The tensor to test.
1178 /// * `dim` - The axis along which to test.
1179 ///
1180 /// # Returns
1181 ///
1182 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1183 /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input
1184 /// evaluates to True, False otherwise.
1185 fn int_any_dim(tensor: IntTensor<B>, dim: usize) -> BoolTensor<B> {
1186 let bool_tensor = B::int_equal_elem(tensor, 0.elem());
1187 let bool_tensor = B::bool_not(bool_tensor);
1188 let sum = B::int_sum_dim(B::bool_into_int(bool_tensor), dim);
1189 B::int_greater_elem(sum, 0.elem())
1190 }
1191
1192 /// Tests if all elements in the int `tensor` evaluate to True.
1193 ///
1194 /// # Arguments
1195 ///
1196 /// * `tensor` - The tensor to test.
1197 ///
1198 /// # Returns
1199 ///
1200 /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1201 /// evaluate to True, False otherwise.
1202 fn int_all(tensor: IntTensor<B>) -> BoolTensor<B> {
1203 let num_elems = tensor.shape().num_elements();
1204 let bool_tensor = B::int_equal_elem(tensor, 0.elem());
1205 let bool_tensor = B::bool_not(bool_tensor);
1206 let sum = B::int_sum(B::bool_into_int(bool_tensor));
1207 B::int_equal_elem(sum, (num_elems as i32).elem())
1208 }
1209
1210 /// Tests if all elements in the int `tensor` evaluate to True along a given dimension `dim`.
1211 ///
1212 /// # Arguments
1213 ///
1214 /// * `tensor` - The tensor to test.
1215 /// * `dim` - The axis along which to test.
1216 ///
1217 /// # Returns
1218 ///
1219 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1220 /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1221 /// evaluates to True, False otherwise.
1222 fn int_all_dim(tensor: IntTensor<B>, dim: usize) -> BoolTensor<B> {
1223 let num_elems = tensor.shape().dims[dim];
1224 let bool_tensor = B::int_equal_elem(tensor, 0.elem());
1225 let bool_tensor = B::bool_not(bool_tensor);
1226 let sum = B::int_sum_dim(B::bool_into_int(bool_tensor), dim);
1227 B::int_equal_elem(sum, (num_elems as i32).elem())
1228 }
1229
1230 /// Returns the signs of the int `tensor`.
1231 ///
1232 /// # Arguments
1233 ///
1234 /// * `tensor` - The tensor to extract the signs from.
1235 ///
1236 /// # Returns
1237 ///
1238 /// A tensor with the same shape as `tensor` containing the signs of the elements of `tensor`.
1239 fn int_sign(tensor: IntTensor<B>) -> IntTensor<B> {
1240 let dtype = tensor.dtype();
1241 let zeros = B::int_zeros(tensor.shape(), &B::int_device(&tensor), dtype.into());
1242 let less_than_zero = B::int_lower_elem(tensor.clone(), 0.0f32.elem());
1243 let greater_than_zero = B::int_greater_elem(tensor, 0.0f32.elem());
1244
1245 let mut result = B::int_mask_fill(zeros, less_than_zero, (-1.0f32).elem());
1246 result = B::int_mask_fill(result, greater_than_zero, 1.0f32.elem());
1247 result
1248 }
1249
    /// Broadcasts the int `tensor` to the given `shape`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to broadcast.
    /// * `shape` - The target shape.
    ///
    /// # Returns
    ///
    /// The tensor broadcast to `shape`.
    fn int_expand(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B>;
1252
1253 /// Sort the elements of the input `tensor` by value along a given dimension.
1254 ///
1255 /// This sort is unstable (i.e., may reorder equal elements).
1256 ///
1257 /// # Arguments
1258 ///
1259 /// * `tensor` - The input tensor.
1260 /// * `dim` - The axis along which to sort.
1261 /// * `descending` - The sorting order.
1262 ///
1263 /// # Returns
1264 ///
1265 /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
1266 fn int_sort(tensor: IntTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
1267 sort::<B, Int>(tensor, dim, descending)
1268 }
1269
1270 /// Sort the elements of the input `tensor` by value along a given dimension.
1271 ///
1272 /// This sort is unstable (i.e., may reorder equal elements).
1273 ///
1274 /// # Arguments
1275 ///
1276 /// * `tensor` - The input tensor.
1277 /// * `dim` - The axis along which to sort.
1278 ///
1279 /// # Returns
1280 ///
1281 /// A tensor with the same shape as the input tensor and corresponding indices, where
1282 /// the elements are sorted by value and the indices map back to the original input tensor.
1283 fn int_sort_with_indices(
1284 tensor: IntTensor<B>,
1285 dim: usize,
1286 descending: bool,
1287 ) -> (IntTensor<B>, IntTensor<B>) {
1288 sort_with_indices::<B, Int>(tensor, dim, descending)
1289 }
1290
1291 /// Returns the indices that sort the elements of the input `tensor` by value
1292 /// along a given dimension.
1293 ///
1294 /// This sort is unstable (i.e., may reorder equal elements).
1295 ///
1296 /// # Arguments
1297 ///
1298 /// * `tensor` - The input tensor.
1299 /// * `dim` - The axis along which to sort.
1300 /// * `descending` - The sorting order.
1301 ///
1302 /// # Returns
1303 ///
1304 /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.
1305 fn int_argsort(tensor: IntTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
1306 argsort::<B, Int>(tensor, dim, descending)
1307 }
1308
    /// Bitwise AND operation for Int Tensors.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// A tensor holding the bitwise AND of `lhs` and `rhs`.
    fn bitwise_and(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

    /// Bitwise AND operation for Int Tensors with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The scalar right-hand side operand.
    ///
    /// # Returns
    ///
    /// A tensor holding the bitwise AND of `lhs` and `rhs`.
    fn bitwise_and_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;

    /// Bitwise OR operation for Int Tensors.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// A tensor holding the bitwise OR of `lhs` and `rhs`.
    fn bitwise_or(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

    /// Bitwise OR operation for Int Tensors with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The scalar right-hand side operand.
    ///
    /// # Returns
    ///
    /// A tensor holding the bitwise OR of `lhs` and `rhs`.
    fn bitwise_or_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;

    /// Bitwise XOR operation for Int Tensors.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// A tensor holding the bitwise XOR of `lhs` and `rhs`.
    fn bitwise_xor(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

    /// Bitwise XOR operation for Int Tensors with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The scalar right-hand side operand.
    ///
    /// # Returns
    ///
    /// A tensor holding the bitwise XOR of `lhs` and `rhs`.
    fn bitwise_xor_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;

    /// Bitwise NOT operation for Int Tensors.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor whose bits are inverted.
    ///
    /// # Returns
    ///
    /// A tensor holding the bitwise NOT of `tensor`.
    fn bitwise_not(tensor: IntTensor<B>) -> IntTensor<B>;

    /// Bitwise left shift operation for Int Tensors.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The tensor to shift.
    /// * `rhs` - The tensor of shift amounts.
    ///
    /// # Returns
    ///
    /// A tensor holding `lhs` shifted left by `rhs`.
    fn bitwise_left_shift(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

    /// Bitwise left shift operation for Int Tensors with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The tensor to shift.
    /// * `rhs` - The scalar shift amount.
    ///
    /// # Returns
    ///
    /// A tensor holding `lhs` shifted left by `rhs`.
    fn bitwise_left_shift_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;

    /// Bitwise right shift operation for Int Tensors.
    ///
    /// NOTE(review): whether this is an arithmetic or logical shift for signed dtypes is not
    /// specified here; confirm against backend implementations.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The tensor to shift.
    /// * `rhs` - The tensor of shift amounts.
    ///
    /// # Returns
    ///
    /// A tensor holding `lhs` shifted right by `rhs`.
    fn bitwise_right_shift(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

    /// Bitwise right shift operation for Int Tensors with a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The tensor to shift.
    /// * `rhs` - The scalar shift amount.
    ///
    /// # Returns
    ///
    /// A tensor holding `lhs` shifted right by `rhs`.
    fn bitwise_right_shift_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
1341
    /// Converts a tensor to another integer data type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to convert.
    /// * `dtype` - The target integer data type.
    ///
    /// # Returns
    ///
    /// A tensor with the same values as `tensor` but in the target integer data type.
    /// NOTE(review): handling of values not representable in `dtype` is left to the backend;
    /// confirm if callers rely on it.
    fn int_cast(tensor: IntTensor<B>, dtype: IntDType) -> IntTensor<B>;
1353
    /// Unfold windows along a dimension.
    ///
    /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`,
    /// where windows are advanced by `step` at each index.
    ///
    /// The number of windows is `(shape[dim] - size) / step + 1` (integer division) when
    /// `shape[dim] >= size`, and `0` otherwise. For example, `shape[dim] = 4`, `size = 2`,
    /// `step = 2` yields the two windows starting at indices 0 and 2.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
    /// * `dim` - the selected dim.
    /// * `size` - the size of each unfolded window.
    /// * `step` - the step between each window.
    ///
    /// # Returns
    ///
    /// A tensor view with shape ``[pre=..., windows, size, post=...]``.
    /// NOTE(review): confirm the position of the `size` axis against backend implementations
    /// (PyTorch's `unfold` appends it as the last axis).
    fn int_unfold(tensor: IntTensor<B>, dim: usize, size: usize, step: usize) -> IntTensor<B>;
1372}