burn_backend/backend/ops/int_tensor.rs
1use super::cat::cat_with_slice_assign;
2use super::repeat_dim::repeat_with_slice_assign;
3use super::sort::{argsort, sort, sort_with_indices};
4use crate::tensor::{BoolTensor, Device, FloatTensor, Int, IntElem, IntTensor};
5use crate::{Backend, Distribution, TensorData, TensorMetadata, element::ElementConversion};
6use crate::{ExecutionError, Scalar, get_device_settings};
7use alloc::vec::Vec;
8use burn_std::{BoolDType, FloatDType, IntDType, Shape, Slice};
9use core::ops::Range;
10
11/// Int Tensor API for basic and numeric operations, see
12#[cfg_attr(doc, doc = crate::doc_tensor!())]
13#[cfg_attr(not(doc), doc = "`Tensor`")]
14/// for documentation on each function.
15pub trait IntTensorOps<B: Backend> {
16 /// Creates a new int tensor.
17 ///
18 /// # Arguments
19 ///
20 /// * `shape` - The shape of the tensor.
21 /// * `device` - The device to create the tensor on.
22 /// * `dtype` - The target data type.
23 ///
24 /// # Returns
25 ///
26 /// The integer tensor with the given shape.
27 fn int_empty(shape: Shape, device: &Device<B>, dtype: IntDType) -> IntTensor<B>;
28
29 /// Converts the tensor to a data structure.
30 ///
31 /// # Arguments
32 ///
33 /// * `tensor` - The tensor.
34 ///
35 /// # Returns
36 ///
37 /// The data structure with the tensor's data.
38 fn int_into_data(
39 tensor: IntTensor<B>,
40 ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;
41
42 /// Creates a tensor from the data structure.
43 ///
44 /// # Arguments
45 ///
46 /// * `data` - The data structure.
47 /// * `device` - The device to create the tensor on.
48 ///
49 /// # Returns
50 ///
51 /// The tensor with the data.
52 fn int_from_data(data: TensorData, device: &Device<B>) -> IntTensor<B>;
53
54 /// Gets the device of the tensor.
55 ///
56 /// # Arguments
57 ///
58 /// * `tensor` - The tensor.
59 ///
60 /// # Returns
61 ///
62 /// The device of the tensor.
63 fn int_device(tensor: &IntTensor<B>) -> Device<B>;
64
65 /// Moves the tensor to the given device.
66 fn int_to_device(tensor: IntTensor<B>, device: &Device<B>) -> IntTensor<B>;
67
68 /// Reshapes the tensor.
69 ///
70 /// # Arguments
71 ///
72 /// * `tensor` - The tensor.
73 /// * `shape` - The new shape.
74 ///
75 /// # Returns
76 ///
77 /// The tensor with the new shape.
78 fn int_reshape(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B>;
79
80 /// Gets the element at the given indices.
81 ///
82 /// # Arguments
83 ///
84 /// * `tensor` - The tensor.
85 /// * `slices` - The slices specifying ranges and steps for each dimension.
86 ///
87 /// # Returns
88 ///
89 /// The elements at the given indices.
90 ///
91 /// # Note
92 ///
93 /// Empty slices (where start >= end) are handled at the high-level tensor API and will not
94 /// be passed to this method. Backend implementations do not need to handle empty slices.
95 fn int_slice(tensor: IntTensor<B>, slices: &[Slice]) -> IntTensor<B>;
96
97 /// Sets the values in the tensor for the given ranges.
98 ///
99 /// # Arguments
100 ///
101 /// * `tensor` - The tensor.
102 /// * `ranges` - The ranges to set the values for.
103 ///
104 /// # Returns
105 ///
106 /// The tensor with the values set for the given ranges.
107 ///
108 /// # Note
109 ///
110 /// Empty slice assignments (where any slice range produces 0 elements) are handled at the
111 /// high-level tensor API and will not be passed to this method. Backend implementations do
112 /// not need to handle empty slice assignments.
113 fn int_slice_assign(
114 tensor: IntTensor<B>,
115 slices: &[Slice],
116 value: IntTensor<B>,
117 ) -> IntTensor<B>;
118
119 /// Converts int tensor to float tensor.
120 ///
121 /// # Arguments
122 ///
123 /// * `tensor` - The tensor.
124 /// * `out_dtype` - The output tensor dtype.
125 ///
126 /// # Returns
127 ///
128 /// The int tensor with the same data as the float tensor.
129 fn int_into_float(tensor: IntTensor<B>, out_dtype: FloatDType) -> FloatTensor<B>;
130
131 /// Fills the tensor with values from the value tensor if the mask is true at the given
132 /// indices.
133 ///
134 /// # Arguments
135 ///
136 /// * `tensor` - The tensor.
137 /// * `mask` - The mask.
138 /// * `value` - The value tensor.
139 ///
140 /// # Returns
141 ///
142 /// The tensor with the values filled.
143 fn int_mask_where(
144 tensor: IntTensor<B>,
145 mask: BoolTensor<B>,
146 value: IntTensor<B>,
147 ) -> IntTensor<B>;
148
149 /// Fills the tensor with the given value if the mask is true at the given indices.
150 ///
151 /// # Arguments
152 ///
153 /// * `tensor` - The tensor.
154 /// * `mask` - The mask.
155 /// * `value` - The value.
156 ///
157 /// # Returns
158 ///
159 /// The tensor with the values filled.
160 fn int_mask_fill(tensor: IntTensor<B>, mask: BoolTensor<B>, value: Scalar) -> IntTensor<B>;
161
162 /// Gather elements from the tensor at the given indices.
163 ///
164 /// # Arguments
165 ///
166 /// * `dim` - The dimension to gather from.
167 /// * `tensor` - The tensor.
168 /// * `indices` - The indices.
169 fn int_gather(dim: usize, tensor: IntTensor<B>, indices: IntTensor<B>) -> IntTensor<B>;
170
171 /// Scatter a given value to the tensor at the given indices using sum reduction.
172 ///
173 /// # Arguments
174 ///
175 /// * `dim` - The dimension to scatter to.
176 /// * `tensor` - The tensor.
177 /// * `indices` - The indices.
178 /// * `value` - The value.
179 ///
180 /// # Returns
181 ///
182 /// The tensor with the values scattered.
183 fn int_scatter_add(
184 dim: usize,
185 tensor: IntTensor<B>,
186 indices: IntTensor<B>,
187 value: IntTensor<B>,
188 ) -> IntTensor<B>;
189
190 /// Multi-dimensional scatter for int tensors.
191 fn int_scatter_nd(
192 _data: IntTensor<B>,
193 _indices: IntTensor<B>,
194 _values: IntTensor<B>,
195 _reduction: crate::tensor::IndexingUpdateOp,
196 ) -> IntTensor<B> {
197 unimplemented!("int_scatter_nd is not implemented for this backend")
198 }
199
200 /// Multi-dimensional gather for int tensors.
201 fn int_gather_nd(_data: IntTensor<B>, _indices: IntTensor<B>) -> IntTensor<B> {
202 unimplemented!("int_gather_nd is not implemented for this backend")
203 }
204
205 /// Select tensor elements along the given dimension corresponding to the given indices.
206 ///
207 /// # Arguments
208 ///
209 /// * `tensor` - The tensor.
210 /// * `dim` - The dimension to select from.
211 /// * `indices` - The indices.
212 ///
213 /// # Returns
214 ///
215 /// The tensor with the selected elements.
216 fn int_select(tensor: IntTensor<B>, dim: usize, indices: IntTensor<B>) -> IntTensor<B>;
217
218 /// Assign the selected elements along the given dimension corresponding to the given indices
219 /// to the given value using sum reduction.
220 ///
221 /// # Arguments
222 ///
223 /// * `tensor` - The tensor.
224 /// * `dim` - The dimension to select from.
225 /// * `indices` - The indices.
226 /// * `value` - The value.
227 ///
228 /// # Returns
229 ///
230 /// The tensor with the selected elements assigned to the given value.
231 fn int_select_add(
232 tensor: IntTensor<B>,
233 dim: usize,
234 indices: IntTensor<B>,
235 value: IntTensor<B>,
236 ) -> IntTensor<B>;
237
238 /// Repeats the tensor along the given dimension the given number of times.
239 ///
240 /// # Arguments
241 ///
242 /// * `tensor` - The tensor.
243 /// * `dim` - The dimension to repeat.
244 /// * `times` - The number of times to repeat.
245 ///
246 /// # Returns
247 ///
248 /// The tensor with the given dimension repeated the given number of times.
249 fn int_repeat_dim(tensor: IntTensor<B>, dim: usize, times: usize) -> IntTensor<B> {
250 repeat_with_slice_assign::<B, Int>(tensor, dim, times)
251 }
252
253 /// Concatenates the given tensors along the given dimension.
254 ///
255 /// # Arguments
256 ///
257 /// * `tensors` - The tensors.
258 /// * `dim` - The dimension to concatenate along.
259 ///
260 /// # Returns
261 ///
262 /// The concatenated tensor.
263 ///
264 /// # Note
265 ///
266 /// Empty tensors (where the concatenation dimension has size 0) are filtered out at the
267 /// high-level tensor API and will not be passed to this method. Backend implementations do
268 /// not need to handle empty tensors.
269 fn int_cat(tensors: Vec<IntTensor<B>>, dim: usize) -> IntTensor<B> {
270 cat_with_slice_assign::<B, Int>(tensors, dim)
271 }
272
273 /// Element-wise equality comparison.
274 ///
275 /// # Arguments
276 ///
277 /// * `lhs` - The left-hand side tensor.
278 /// * `rhs` - The right-hand side tensor.
279 /// * `out_dtype` - The output tensor dtype.
280 ///
281 /// # Returns
282 ///
283 /// The boolean tensor with the result of the comparison.
284 fn int_equal(lhs: IntTensor<B>, rhs: IntTensor<B>, out_dtype: BoolDType) -> BoolTensor<B>;
285
286 /// Element-wise non-equality comparison.
287 ///
288 /// # Arguments
289 ///
290 /// * `lhs` - The left-hand side tensor.
291 /// * `rhs` - The right-hand side tensor.
292 /// * `out_dtype` - The output tensor dtype.
293 ///
294 /// # Returns
295 ///
296 /// The boolean tensor with the result of the comparison.
297 fn int_not_equal(lhs: IntTensor<B>, rhs: IntTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
298 let equal_tensor = B::int_equal(lhs, rhs, out_dtype);
299 B::bool_not(equal_tensor)
300 }
301
302 /// Element-wise equality comparison with a scalar.
303 ///
304 /// # Arguments
305 ///
306 /// * `lhs` - The left-hand side tensor.
307 /// * `rhs` - The right-hand side scalar.
308 /// * `out_dtype` - The output tensor dtype.
309 ///
310 /// # Returns
311 ///
312 /// The boolean tensor with the result of the comparison.
313 fn int_equal_elem(lhs: IntTensor<B>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<B>;
314
315 /// Element-wise non-equality comparison with a scalar.
316 ///
317 /// # Arguments
318 ///
319 /// * `lhs` - The left-hand side tensor.
320 /// * `rhs` - The right-hand side scalar.
321 /// * `out_dtype` - The output tensor dtype.
322 ///
323 /// # Returns
324 ///
325 /// The boolean tensor with the result of the comparison.
326 fn int_not_equal_elem(lhs: IntTensor<B>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<B> {
327 let equal_tensor = B::int_equal_elem(lhs, rhs, out_dtype);
328 B::bool_not(equal_tensor)
329 }
330
331 /// Element-wise greater than comparison.
332 ///
333 /// # Arguments
334 ///
335 /// * `lhs` - The left-hand side tensor.
336 /// * `rhs` - The right-hand side tensor.
337 /// * `out_dtype` - The output tensor dtype.
338 ///
339 /// # Returns
340 ///
341 /// The boolean tensor with the result of the comparison.
342 fn int_greater(lhs: IntTensor<B>, rhs: IntTensor<B>, out_dtype: BoolDType) -> BoolTensor<B>;
343
344 /// Element-wise greater than comparison with a scalar.
345 ///
346 /// # Arguments
347 ///
348 /// * `lhs` - The left-hand side tensor.
349 /// * `rhs` - The right-hand side scalar.
350 /// * `out_dtype` - The output tensor dtype.
351 ///
352 /// # Returns
353 ///
354 /// The boolean tensor with the result of the comparison.
355 fn int_greater_elem(lhs: IntTensor<B>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<B>;
356
357 /// Element-wise greater than or equal comparison.
358 ///
359 /// # Arguments
360 ///
361 /// * `lhs` - The left-hand side tensor.
362 /// * `rhs` - The right-hand side tensor.
363 /// * `out_dtype` - The output tensor dtype.
364 ///
365 /// # Returns
366 ///
367 /// The boolean tensor with the result of the comparison.
368 fn int_greater_equal(
369 lhs: IntTensor<B>,
370 rhs: IntTensor<B>,
371 out_dtype: BoolDType,
372 ) -> BoolTensor<B>;
373
374 /// Element-wise greater than or equal comparison with a scalar.
375 ///
376 /// # Arguments
377 ///
378 /// * `lhs` - The left-hand side tensor.
379 /// * `rhs` - The right-hand side scalar.
380 /// * `out_dtype` - The output tensor dtype.
381 ///
382 /// # Returns
383 ///
384 /// The boolean tensor with the result of the comparison.
385 fn int_greater_equal_elem(
386 lhs: IntTensor<B>,
387 rhs: Scalar,
388 out_dtype: BoolDType,
389 ) -> BoolTensor<B>;
390
391 /// Element-wise less than comparison.
392 ///
393 /// # Arguments
394 ///
395 /// * `lhs` - The left-hand side tensor.
396 /// * `rhs` - The right-hand side tensor.
397 /// * `out_dtype` - The output tensor dtype.
398 ///
399 /// # Returns
400 ///
401 /// The boolean tensor with the result of the comparison.
402 fn int_lower(lhs: IntTensor<B>, rhs: IntTensor<B>, out_dtype: BoolDType) -> BoolTensor<B>;
403
404 /// Element-wise less than comparison with a scalar.
405 ///
406 /// # Arguments
407 ///
408 /// * `lhs` - The left-hand side tensor.
409 /// * `rhs` - The right-hand side scalar.
410 /// * `out_dtype` - The output tensor dtype.
411 ///
412 /// # Returns
413 ///
414 /// The boolean tensor with the result of the comparison.
415 fn int_lower_elem(lhs: IntTensor<B>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<B>;
416
417 /// Element-wise less than or equal comparison.
418 ///
419 /// # Arguments
420 ///
421 /// * `lhs` - The left-hand side tensor.
422 /// * `rhs` - The right-hand side tensor.
423 /// * `out_dtype` - The output tensor dtype.
424 ///
425 /// # Returns
426 ///
427 /// The boolean tensor with the result of the comparison.
428 fn int_lower_equal(lhs: IntTensor<B>, rhs: IntTensor<B>, out_dtype: BoolDType)
429 -> BoolTensor<B>;
430
431 /// Element-wise less than or equal comparison with a scalar.
432 ///
433 /// # Arguments
434 ///
435 /// * `lhs` - The left-hand side tensor.
436 /// * `rhs` - The right-hand side scalar.
437 /// * `out_dtype` - The output tensor dtype.
438 ///
439 /// # Returns
440 ///
441 /// The boolean tensor with the result of the comparison.
442 fn int_lower_equal_elem(lhs: IntTensor<B>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<B>;
443
444 // ==== NUMERIC ==== //
445
446 /// Element-wise addition.
447 ///
448 /// # Arguments
449 ///
450 /// * `lhs` - The left-hand side tensor.
451 /// * `rhs` - The right-hand side tensor.
452 ///
453 /// # Returns
454 ///
455 /// The result of the addition.
456 fn int_add(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
457
458 /// Element-wise addition with a scalar.
459 ///
460 /// # Arguments
461 ///
462 /// * `lhs` - The left-hand side tensor.
463 /// * `rhs` - The right-hand side scalar.
464 ///
465 /// # Returns
466 ///
467 /// The result of the addition.
468 fn int_add_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
469
470 /// Element-wise power with a IntTensor.
471 ///
472 /// # Arguments
473 ///
474 /// * `lhs` - The left-hand side IntTensor.
475 /// * `rhs` - The right-hand side IntTensor.
476 ///
477 /// # Returns
478 ///
479 /// The elements of `lhs` raised to the power of the elements of `rhs`.
480 fn int_powi(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
481 let dtype = lhs.dtype();
482 let float_dtype = get_device_settings::<B>(&B::int_device(&lhs)).float_dtype;
483 B::float_into_int(
484 B::float_powi(B::int_into_float(lhs, float_dtype), rhs),
485 dtype.into(),
486 )
487 }
488
489 /// Element-wise power with a scalar.
490 ///
491 /// # Backend Implementors Note
492 ///
493 /// A number of common exponent cases can be implemented with operations
494 /// which are much cheaper than generic exponentiation.
495 ///
496 /// This (`Backend` impl overridable) operation handles generic optimizations
497 /// for several common integer exponent cases; and then dispatches to
498 /// the (`Backend` impl overridable) [`Self::int_powi_scalar_impl`]
499 /// operation to handle the generic case.
500 ///
501 /// # Arguments
502 ///
503 /// * `lhs` - The left-hand side tensor.
504 /// * `rhs` - The right-hand side scalar.
505 ///
506 /// # Returns
507 ///
508 /// The elements of `lhs` raised to the value of `rhs`.
509 fn int_powi_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B> {
510 let exp = rhs.elem::<i32>();
511 match exp {
512 0 => Self::int_ones(lhs.shape(), &B::int_device(&lhs), lhs.dtype().into()),
513 1 => lhs,
514 2 => Self::int_mul(lhs.clone(), lhs),
515 _ => Self::int_powi_scalar_impl(lhs, rhs),
516 }
517 }
518
519 /// Element-wise power with a scalar.
520 ///
521 /// # Backend Implementors Note
522 ///
523 /// This is the generic implementation of integer exponentiation
524 /// called by [`Self::int_powi_scalar`] in the fallback case.
525 ///
526 /// By default, this performs a relatively expensive conversion to float,
527 /// exponentiation in float, and conversion back to int.
528 /// This reduces the minimal operation set for `Backend`s,
529 /// at the cost of performance.
530 ///
531 /// This is a good target for specialized optimizations in `Backend` implementations.
532 ///
533 /// As a general rule, this should not be called directly.
534 ///
535 /// # Arguments
536 ///
537 /// * `lhs` - The left-hand side tensor.
538 /// * `rhs` - The right-hand side scalar.
539 ///
540 /// # Returns
541 ///
542 /// The elements of `lhs` raised to the value of `rhs`.
543 fn int_powi_scalar_impl(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B> {
544 let dtype = lhs.dtype();
545 let float_dtype = get_device_settings::<B>(&B::int_device(&lhs)).float_dtype;
546 B::float_into_int(
547 B::float_powi_scalar_impl(B::int_into_float(lhs, float_dtype), rhs),
548 dtype.into(),
549 )
550 }
551
552 /// Clamps a tensor under a minimum value.
553 ///
554 /// # Arguments
555 ///
556 /// * `tensor` - The tensor to clamp.
557 /// * `min` - The minimum value.
558 ///
559 /// # Returns
560 ///
561 /// The clamped tensor.
562 fn int_clamp_min(tensor: IntTensor<B>, min: Scalar) -> IntTensor<B> {
563 let dtype = get_device_settings::<B>(&B::int_device(&tensor)).bool_dtype;
564 let mask = Self::int_lower_elem(tensor.clone(), min, dtype);
565 Self::int_mask_fill(tensor, mask, min)
566 }
567
568 /// Clamps a tensor over a maximum value.
569 ///
570 /// # Arguments
571 ///
572 /// * `tensor` - The tensor to clamp.
573 /// * `max` - The maximum value.
574 ///
575 /// # Returns
576 ///
577 /// The clamped tensor.
578 fn int_clamp_max(tensor: IntTensor<B>, max: Scalar) -> IntTensor<B> {
579 let dtype = get_device_settings::<B>(&B::int_device(&tensor)).bool_dtype;
580 let mask = Self::int_greater_elem(tensor.clone(), max, dtype);
581 Self::int_mask_fill(tensor, mask, max)
582 }
583
584 /// Clamps a tensor between a minimum and maximum value.
585 ///
586 /// # Arguments
587 ///
588 /// * `tensor` - The tensor to clamp.
589 /// * `min` - The minimum value.
590 /// * `max` - The maximum value.
591 ///
592 /// # Returns
593 ///
594 /// The clamped tensor.
595 fn int_clamp(tensor: IntTensor<B>, min: Scalar, max: Scalar) -> IntTensor<B> {
596 Self::int_clamp_min(Self::int_clamp_max(tensor, max), min)
597 }
598
599 /// Element-wise subtraction.
600 ///
601 /// # Arguments
602 ///
603 /// * `lhs` - The left-hand side tensor.
604 /// * `rhs` - The right-hand side tensor.
605 ///
606 /// # Returns
607 ///
608 /// The result of the subtraction.
609 fn int_sub(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
610
611 /// Element-wise subtraction with a scalar.
612 ///
613 /// # Arguments
614 ///
615 /// * `lhs` - The left-hand side tensor.
616 /// * `rhs` - The right-hand side scalar.
617 ///
618 /// # Returns
619 ///
620 /// The result of the subtraction.
621 fn int_sub_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
622
623 /// Element-wise multiplication.
624 ///
625 /// # Arguments
626 ///
627 /// * `lhs` - The left-hand side tensor.
628 /// * `rhs` - The right-hand side tensor.
629 ///
630 /// # Returns
631 ///
632 /// The result of the multiplication.
633 fn int_mul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
634
635 /// Element-wise multiplication with a scalar.
636 ///
637 /// # Arguments
638 ///
639 /// * `lhs` - The left-hand side tensor.
640 /// * `rhs` - The right-hand side scalar.
641 ///
642 /// # Returns
643 ///
644 /// The result of the multiplication.
645 fn int_mul_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
646
647 /// Element-wise division.
648 ///
649 /// # Arguments
650 ///
651 /// * `lhs` - The left-hand side tensor.
652 /// * `rhs` - The right-hand side tensor.
653 ///
654 /// # Returns
655 ///
656 /// The result of the division.
657 fn int_div(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
658
659 /// Element-wise division with a scalar.
660 ///
661 /// # Arguments
662 ///
663 /// * `lhs` - The left-hand side tensor.
664 /// * `rhs` - The right-hand side scalar.
665 ///
666 /// # Returns
667 ///
668 /// The result of the division.
669 fn int_div_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
670
671 /// Element-wise modulus.
672 ///
673 /// # Arguments
674 /// * `lhs` - The left-hand side tensor.
675 /// * `rhs` - The right-hand side scalar.
676 ///
677 /// # Returns
678 ///
679 /// The result of applying the modulus of the scalar to the tensor.
680 fn int_remainder(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
681
682 /// Element-wise modulus with a scalar.
683 ///
684 /// # Arguments
685 /// * `lhs` - The left-hand side tensor.
686 /// * `rhs` - The right-hand side scalar.
687 ///
688 /// # Returns
689 ///
690 /// The result of applying the modulus of the scalar to the tensor.
691 fn int_remainder_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
692
693 /// Multiplies two tensors together using matrix multiplication.
694 ///
695 /// # Arguments
696 ///
697 /// * `lhs` - The left-hand side tensor.
698 /// * `rhs` - The right-hand side tensor.
699 ///
700 /// # Returns
701 ///
702 /// The result of multiplying the two tensors together using matrix multiplication.
703 fn int_matmul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
704
705 /// Element-wise negation.
706 ///
707 /// # Arguments
708 ///
709 /// * `tensor` - The tensor to negate.
710 ///
711 /// # Returns
712 ///
713 /// The negated tensor.
714 fn int_neg(tensor: IntTensor<B>) -> IntTensor<B> {
715 Self::int_mul_scalar(tensor, (-1).into())
716 }
717
718 /// Creates a tensor of zeros.
719 ///
720 /// # Arguments
721 ///
722 /// * `shape` - The shape of the tensor.
723 /// * `device` - The device to create the tensor on.
724 /// * `dtype` - The target data type.
725 ///
726 /// # Returns
727 ///
728 /// The tensor of zeros.
729 fn int_zeros(shape: Shape, device: &Device<B>, dtype: IntDType) -> IntTensor<B> {
730 Self::int_from_data(TensorData::full_dtype(shape, 0, dtype.into()), device)
731 }
732
733 /// Creates a tensor of ones.
734 ///
735 /// # Arguments
736 ///
737 /// * `shape` - The shape of the tensor.
738 /// * `device` - The device to create the tensor on.
739 /// * `dtype` - The target data type.
740 ///
741 /// # Returns
742 ///
743 /// The tensor of ones.
744 fn int_ones(shape: Shape, device: &Device<B>, dtype: IntDType) -> IntTensor<B> {
745 Self::int_from_data(TensorData::full_dtype(shape, 1, dtype.into()), device)
746 }
747
748 /// Creates a tensor filled with given value.
749 ///
750 /// # Arguments
751 ///
752 /// * `shape` - The shape of the tensor.
753 /// * `fill_value` - The value with which to fill the tensor.
754 /// * `device` - The device to create the tensor on.
755 /// * `dtype` - The target data type.
756 ///
757 /// # Returns
758 ///
759 /// The tensor filled with given value
760 fn int_full(
761 shape: Shape,
762 fill_value: Scalar,
763 device: &Device<B>,
764 dtype: IntDType,
765 ) -> IntTensor<B> {
766 Self::int_from_data(
767 TensorData::full_dtype(shape, fill_value, dtype.into()),
768 device,
769 )
770 }
771
772 /// Sums all elements in the tensor.
773 ///
774 /// # Arguments
775 ///
776 /// * `tensor` - The tensor to sum.
777 ///
778 /// # Returns
779 ///
780 /// The sum of all elements in the tensor.
781 fn int_sum(tensor: IntTensor<B>) -> IntTensor<B>;
782
783 /// Sums all elements in the tensor along a dimension.
784 ///
785 /// # Arguments
786 ///
787 /// * `tensor` - The tensor to sum.
788 /// * `dim` - The dimension to sum along.
789 ///
790 /// # Returns
791 ///
792 /// The sum of all elements in the tensor along the dimension.
793 fn int_sum_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
794
795 /// Computes the product of all elements in the tensor.
796 ///
797 /// # Arguments
798 ///
799 /// * `tensor` - The tensor to compute the product of.
800 ///
801 /// # Returns
802 ///
803 /// The product of all elements in the tensor.
804 fn int_prod(tensor: IntTensor<B>) -> IntTensor<B>;
805
806 /// Computes the product of all elements in the tensor along a dimension.
807 ///
808 /// # Arguments
809 ///
810 /// * `tensor` - The tensor to compute the product of.
811 /// * `dim` - The dimension to compute the product along.
812 ///
813 /// # Returns
814 ///
815 /// The product of all elements in the tensor along the dimension.
816 fn int_prod_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
817
818 /// Computes the mean of all elements in the tensor.
819 ///
820 /// # Arguments
821 ///
822 /// * `tensor` - The tensor to compute the mean of.
823 ///
824 /// # Returns
825 ///
826 /// The mean of all elements in the tensor.
827 fn int_mean(tensor: IntTensor<B>) -> IntTensor<B> {
828 let num_elems = tensor.shape().num_elements() as i64;
829 B::int_div_scalar(B::int_sum(tensor), num_elems.into())
830 }
831
832 /// Computes the mean of all elements in the tensor along a dimension.
833 ///
834 /// # Arguments
835 ///
836 /// * `tensor` - The tensor to compute the mean of.
837 ///
838 /// # Returns
839 ///
840 /// The mean of all elements in the tensor along the dimension.
841 fn int_mean_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
842
843 /// Computes the cumulative sum of elements along a dimension.
844 ///
845 /// # Arguments
846 ///
847 /// * `tensor` - The tensor to compute the cumulative sum of.
848 /// * `dim` - The dimension along which to compute the cumulative sum.
849 ///
850 /// # Returns
851 ///
852 /// A tensor with the same shape where each element is the cumulative sum
853 /// of all elements up to and including that position along the dimension.
854 fn int_cumsum(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
855
856 /// Computes the cumulative product of elements along a dimension.
857 ///
858 /// # Arguments
859 ///
860 /// * `tensor` - The tensor to compute the cumulative product of.
861 /// * `dim` - The dimension along which to compute the cumulative product.
862 ///
863 /// # Returns
864 ///
865 /// A tensor with the same shape where each element is the cumulative product
866 /// of all elements up to and including that position along the dimension.
867 fn int_cumprod(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
868
869 /// Computes the cumulative minimum of elements along a dimension.
870 ///
871 /// # Arguments
872 ///
873 /// * `tensor` - The tensor to compute the cumulative minimum of.
874 /// * `dim` - The dimension along which to compute the cumulative minimum.
875 ///
876 /// # Returns
877 ///
878 /// A tensor with the same shape where each element is the minimum
879 /// of all elements up to and including that position along the dimension.
880 fn int_cummin(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
881
882 /// Computes the cumulative maximum of elements along a dimension.
883 ///
884 /// # Arguments
885 ///
886 /// * `tensor` - The tensor to compute the cumulative maximum of.
887 /// * `dim` - The dimension along which to compute the cumulative maximum.
888 ///
889 /// # Returns
890 ///
891 /// A tensor with the same shape where each element is the maximum
892 /// of all elements up to and including that position along the dimension.
893 fn int_cummax(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
894
895 /// Gets the indices of the maximum elements along a dimension.
896 ///
897 /// # Arguments
898 ///
899 /// * `tensor` - The tensor to get the maximum indices of.
900 /// * `dim` - The dimension to get the maximum indices along.
901 ///
902 /// # Returns
903 ///
904 /// The indices of the maximum elements along the dimension.
905 fn int_argmax(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
906
907 /// Gets the indices of the k maximum elements along a dimension.
908 /// If two elements share the same value, it will be ordered by the lowest
909 /// coordinate
910 ///
911 /// # Arguments
912 ///
913 /// * `tensor` - The tensor to get the maximum indices of.
914 /// * `dim` - The dimension to get the maximum indices along.
915 /// * `k` - number of maximum elements.
916 ///
917 /// # Returns
918 ///
919 /// The indices of the maximum elements along the dimension.
920 fn int_argtopk(tensor: IntTensor<B>, dim: usize, k: usize) -> IntTensor<B>;
921
922 /// Gets the values of the k maximum elements along a dimension.
923 /// # Arguments
924 ///
925 /// * `tensor` - The tensor to get the maximum values of.
926 /// * `dim` - The dimension to get the maximum values along.
927 /// * `k` - number of maximum elements.
928 ///
929 /// # Returns
930 ///
931 /// The values of the maximum elements along the dimension.
932 fn int_topk(tensor: IntTensor<B>, dim: usize, k: usize) -> IntTensor<B>;
933
934 /// Gets the indices of the minimum elements along a dimension.
935 ///
936 /// # Arguments
937 ///
938 /// * `tensor` - The tensor to get the minimum indices of.
939 /// * `dim` - The dimension to get the minimum indices along.
940 ///
941 /// # Returns
942 ///
943 /// The indices of the minimum elements along the dimension.
944 fn int_argmin(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
945
946 /// Gets the maximum element in the tensor.
947 ///
948 /// # Arguments
949 ///
950 /// * `tensor` - The tensor to get the maximum element of.
951 ///
952 /// # Returns
953 ///
954 /// The maximum element in the tensor.
955 fn int_max(tensor: IntTensor<B>) -> IntTensor<B> {
956 let shape = tensor.shape();
957 let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()]));
958
959 B::int_max_dim(tensor, 0)
960 }
961
962 /// Gets the maximum element in the tensor along a dimension.
963 ///
964 /// # Arguments
965 ///
966 /// * `tensor` - The tensor to get the maximum element of.
967 /// * `dim` - The dimension to get the maximum element along.
968 ///
969 /// # Returns
970 ///
971 /// The maximum element in the tensor along the dimension.
972 fn int_max_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
973 let index = B::int_argmax(tensor.clone(), dim);
974 B::int_gather(dim, tensor, index)
975 }
976
977 /// Gets the maximum elements and corresponding indices along a dimension.
978 ///
979 /// # Arguments
980 ///
981 /// * `tensor` - The tensor to get the maximum elements and indices of.
982 /// * `dim` - The dimension to get the maximum elements and indices along.
983 ///
984 /// # Returns
985 ///
986 /// The maximum elements and corresponding indices along the dimension.
987 fn int_max_dim_with_indices(tensor: IntTensor<B>, dim: usize) -> (IntTensor<B>, IntTensor<B>) {
988 let index = B::int_argmax(tensor.clone(), dim);
989 let values = B::int_gather(dim, tensor, index.clone());
990
991 (values, index)
992 }
993
994 /// Gets the maximum absolute element in the tensor.
995 ///
996 /// # Arguments
997 ///
998 /// * `tensor` - The tensor to get the maximum element of.
999 ///
1000 /// # Returns
1001 ///
1002 /// The maximum element in the tensor.
1003 fn int_max_abs(tensor: IntTensor<B>) -> IntTensor<B> {
1004 let shape = tensor.shape();
1005 let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()]));
1006
1007 B::int_max_abs_dim(tensor, 0)
1008 }
1009
1010 /// Gets the maximum absolute element in the tensor along a dimension.
1011 ///
1012 /// # Arguments
1013 ///
1014 /// * `tensor` - The tensor to get the maximum element of.
1015 /// * `dim` - The dimension to get the maximum element along.
1016 ///
1017 /// # Returns
1018 ///
1019 /// The maximum element in the tensor along the dimension.
1020 fn int_max_abs_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
1021 B::int_max_dim(B::int_abs(tensor), dim)
1022 }
1023
1024 /// Gets the minimum element in the tensor.
1025 ///
1026 /// # Arguments
1027 ///
1028 /// * `tensor` - The tensor to get the minimum element of.
1029 ///
1030 /// # Returns
1031 ///
1032 /// The minimum element in the tensor.
1033 fn int_min(tensor: IntTensor<B>) -> IntTensor<B> {
1034 let shape = tensor.shape();
1035 let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()]));
1036
1037 B::int_min_dim(tensor, 0)
1038 }
1039
1040 /// Gets the minimum elements in the tensor along a dimension.
1041 ///
1042 /// # Arguments
1043 ///
1044 /// * `tensor` - The tensor to get the minimum element of.
1045 /// * `dim` - The dimension to get the minimum element along.
1046 ///
1047 /// # Returns
1048 ///
1049 /// The minimum element in the tensor along the dimension.
1050 fn int_min_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
1051 let index = B::int_argmin(tensor.clone(), dim);
1052 B::int_gather(dim, tensor, index)
1053 }
1054
1055 /// Gets the minimum elements and corresponding indices along a dimension.
1056 ///
1057 /// # Arguments
1058 ///
1059 /// * `tensor` - The tensor to get the minimum elements and indices of.
1060 /// * `dim` - The dimension to get the minimum elements and indices along.
1061 ///
1062 /// # Returns
1063 ///
1064 /// The minimum elements and corresponding indices along the dimension.
1065 fn int_min_dim_with_indices(tensor: IntTensor<B>, dim: usize) -> (IntTensor<B>, IntTensor<B>) {
1066 let indices = B::int_argmin(tensor.clone(), dim);
1067 let values = B::int_gather(dim, tensor, indices.clone());
1068
1069 (values, indices)
1070 }
1071
1072 /// Returns a new tensor with absolute values.
1073 ///
1074 /// # Arguments
1075 ///
1076 /// * `tensor` - The tensor to take absolute value of.
1077 ///
1078 /// # Returns
1079 ///
1080 /// A tensor with the same shape as `tensor` with absolute values.
1081 fn int_abs(tensor: IntTensor<B>) -> IntTensor<B>;
1082
1083 /// Transposes an int tensor.
1084 ///
1085 /// # Arguments
1086 ///
1087 /// * `tensor` - The tensor to transpose.
1088 ///
1089 /// # Returns
1090 ///
1091 /// The transposed tensor.
1092 fn int_transpose(tensor: IntTensor<B>) -> IntTensor<B> {
1093 let ndims = tensor.shape().num_dims();
1094 Self::int_swap_dims(tensor, ndims - 2, ndims - 1)
1095 }
1096
1097 /// Swaps two dimensions of an int tensor.
1098 ///
1099 /// # Arguments
1100 ///
1101 /// * `tensor` - The tensor to swap the dimensions of.
1102 /// * `dim1` - The first dimension to swap.
1103 /// * `dim2` - The second dimension to swap.
1104 ///
1105 /// # Returns
1106 ///
1107 /// The tensor with the dimensions swapped.
1108 fn int_swap_dims(tensor: IntTensor<B>, dim1: usize, dim2: usize) -> IntTensor<B>;
1109
1110 /// Permutes the dimensions of a tensor.
1111 ///
1112 /// # Arguments
1113 ///
1114 /// * `tensor` - The tensor to permute the dimensions of.
1115 /// * `axes` - The new order of the dimensions.
1116 /// # Returns
1117 ///
1118 /// The tensor with the dimensions permuted.
1119 fn int_permute(tensor: IntTensor<B>, axes: &[usize]) -> IntTensor<B>;
1120
1121 /// Reverse the order of elements in a tensor along the given axes.
1122 ///
1123 /// # Arguments
1124 ///
1125 /// * `tensor` - The tensor to reverse.
1126 /// * `axes` - The axes to reverse.
1127 ///
1128 /// The tensor with the elements reversed.
1129 fn int_flip(tensor: IntTensor<B>, axes: &[usize]) -> IntTensor<B>;
1130
1131 /// Creates a new int tensor with random values.
1132 ///
1133 /// # Arguments
1134 /// * `shape` - The shape of the tensor.
1135 /// * `distribution` - The distribution to sample from.
1136 /// * `device` - The device to create the tensor on.
1137 /// * `dtype` - The target data type.
1138 ///
1139 /// # Returns
1140 ///
1141 /// The tensor with the given shape and random values.
1142 fn int_random(
1143 shape: Shape,
1144 distribution: Distribution,
1145 device: &Device<B>,
1146 dtype: IntDType,
1147 ) -> IntTensor<B>;
1148
1149 /// Creates a new tensor with values from the given range with the given step size.
1150 ///
1151 /// # Arguments
1152 ///
1153 /// * `range` - The range of values.
1154 /// * `step` - The step size.
1155 /// * `device` - The device to create the tensor on.
1156 /// * `dtype` - The target data type.
1157 ///
1158 /// # Returns
1159 ///
1160 /// The tensor with the given values.
1161 fn int_arange_step(
1162 range: Range<i64>,
1163 step: usize,
1164 device: &Device<B>,
1165 dtype: IntDType,
1166 ) -> IntTensor<B> {
1167 let value = range
1168 .step_by(step)
1169 .map(|i| i.elem())
1170 .collect::<Vec<IntElem<B>>>();
1171 let shape = Shape::new([value.len()]);
1172 let data = TensorData::new(value, shape).convert_dtype(dtype.into());
1173 B::int_from_data(data, device)
1174 }
1175
1176 /// Creates a new tensor with values from the given range.
1177 ///
1178 /// # Arguments
1179 ///
1180 /// * `range` - The range of values.
1181 /// * `device` - The device to create the tensor on.
1182 ///
1183 /// # Returns
1184 ///
1185 /// The tensor with the given values.
1186 ///
1187 /// # Remarks
1188 ///
1189 /// Uses `arange_step` with a step size of 1 under the hood.
1190 fn int_arange(range: Range<i64>, device: &Device<B>, dtype: IntDType) -> IntTensor<B> {
1191 Self::int_arange_step(range, 1, device, dtype)
1192 }
1193
1194 /// Tests if any element in the int `tensor` evaluates to True.
1195 ///
1196 /// # Arguments
1197 ///
1198 /// * `tensor` - The tensor to test.
1199 ///
1200 /// # Returns
1201 ///
1202 /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
1203 fn int_any(tensor: IntTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
1204 let int_dtype = tensor.dtype();
1205 let bool_tensor = B::int_equal_elem(tensor, 0.into(), out_dtype);
1206 let bool_tensor = B::bool_not(bool_tensor);
1207 let sum = B::int_sum(B::bool_into_int(bool_tensor, int_dtype.into()));
1208 B::int_greater_elem(sum, 0.into(), out_dtype)
1209 }
1210
1211 /// Tests if any element in the int `tensor` evaluates to True along a given dimension `dim`.
1212 ///
1213 /// # Arguments
1214 ///
1215 /// * `tensor` - The tensor to test.
1216 /// * `dim` - The axis along which to test.
1217 ///
1218 /// # Returns
1219 ///
1220 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1221 /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input
1222 /// evaluates to True, False otherwise.
1223 fn int_any_dim(tensor: IntTensor<B>, dim: usize, out_dtype: BoolDType) -> BoolTensor<B> {
1224 let int_dtype = tensor.dtype();
1225 let bool_tensor = B::int_equal_elem(tensor, 0.into(), out_dtype);
1226 let bool_tensor = B::bool_not(bool_tensor);
1227 let sum = B::int_sum_dim(B::bool_into_int(bool_tensor, int_dtype.into()), dim);
1228 B::int_greater_elem(sum, 0.into(), out_dtype)
1229 }
1230
1231 /// Tests if all elements in the int `tensor` evaluate to True.
1232 ///
1233 /// # Arguments
1234 ///
1235 /// * `tensor` - The tensor to test.
1236 /// * `out_dtype` - The output tensor dtype.
1237 ///
1238 /// # Returns
1239 ///
1240 /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1241 /// evaluate to True, False otherwise.
1242 fn int_all(tensor: IntTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
1243 let int_dtype = tensor.dtype();
1244 let num_elems = tensor.shape().num_elements() as i64;
1245 let bool_tensor = B::int_equal_elem(tensor, 0.into(), out_dtype);
1246 let bool_tensor = B::bool_not(bool_tensor);
1247 let sum = B::int_sum(B::bool_into_int(bool_tensor, int_dtype.into()));
1248 B::int_equal_elem(sum, num_elems.into(), out_dtype)
1249 }
1250
1251 /// Tests if all elements in the int `tensor` evaluate to True along a given dimension `dim`.
1252 ///
1253 /// # Arguments
1254 ///
1255 /// * `tensor` - The tensor to test.
1256 /// * `dim` - The axis along which to test.
1257 /// * `out_dtype` - The output tensor dtype.
1258 ///
1259 /// # Returns
1260 ///
1261 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1262 /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1263 /// evaluates to True, False otherwise.
1264 fn int_all_dim(tensor: IntTensor<B>, dim: usize, out_dtype: BoolDType) -> BoolTensor<B> {
1265 let int_dtype = tensor.dtype();
1266 let num_elems = tensor.shape()[dim] as i64;
1267 let bool_tensor = B::int_equal_elem(tensor, 0.into(), out_dtype);
1268 let bool_tensor = B::bool_not(bool_tensor);
1269 let sum = B::int_sum_dim(B::bool_into_int(bool_tensor, int_dtype.into()), dim);
1270 B::int_equal_elem(sum, num_elems.into(), out_dtype)
1271 }
1272
1273 /// Returns the signs of the int `tensor`.
1274 ///
1275 /// # Arguments
1276 ///
1277 /// * `tensor` - The tensor to extract the signs from.
1278 ///
1279 /// # Returns
1280 ///
1281 /// A tensor with the same shape as `tensor` containing the signs of the elements of `tensor`.
1282 fn int_sign(tensor: IntTensor<B>) -> IntTensor<B> {
1283 let dtype = tensor.dtype();
1284 let device = B::int_device(&tensor);
1285 let bool_dtype = get_device_settings::<B>(&B::int_device(&tensor)).bool_dtype;
1286 let zeros = B::int_zeros(tensor.shape(), &device, dtype.into());
1287 let less_than_zero = B::int_lower_elem(tensor.clone(), 0.into(), bool_dtype);
1288 let greater_than_zero = B::int_greater_elem(tensor, 0.into(), bool_dtype);
1289
1290 let mut result = B::int_mask_fill(zeros, less_than_zero, (-1).into());
1291 result = B::int_mask_fill(result, greater_than_zero, 1.into());
1292 result
1293 }
1294
1295 /// Broadcasts the int `tensor` to the given `shape`.
1296 fn int_expand(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B>;
1297
1298 /// Sort the elements of the input `tensor` by value along a given dimension.
1299 ///
1300 /// This sort is unstable (i.e., may reorder equal elements).
1301 ///
1302 /// # Arguments
1303 ///
1304 /// * `tensor` - The input tensor.
1305 /// * `dim` - The axis along which to sort.
1306 /// * `descending` - The sorting order.
1307 ///
1308 /// # Returns
1309 ///
1310 /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
1311 fn int_sort(tensor: IntTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
1312 sort::<B, Int>(tensor, dim, descending)
1313 }
1314
1315 /// Sort the elements of the input `tensor` by value along a given dimension.
1316 ///
1317 /// This sort is unstable (i.e., may reorder equal elements).
1318 ///
1319 /// # Arguments
1320 ///
1321 /// * `tensor` - The input tensor.
1322 /// * `dim` - The axis along which to sort.
1323 ///
1324 /// # Returns
1325 ///
1326 /// A tensor with the same shape as the input tensor and corresponding indices, where
1327 /// the elements are sorted by value and the indices map back to the original input tensor.
1328 fn int_sort_with_indices(
1329 tensor: IntTensor<B>,
1330 dim: usize,
1331 descending: bool,
1332 ) -> (IntTensor<B>, IntTensor<B>) {
1333 let dtype = tensor.dtype();
1334 sort_with_indices::<B, Int>(tensor, dim, descending, dtype.into())
1335 }
1336
1337 /// Returns the indices that sort the elements of the input `tensor` by value
1338 /// along a given dimension.
1339 ///
1340 /// This sort is unstable (i.e., may reorder equal elements).
1341 ///
1342 /// # Arguments
1343 ///
1344 /// * `tensor` - The input tensor.
1345 /// * `dim` - The axis along which to sort.
1346 /// * `descending` - The sorting order.
1347 ///
1348 /// # Returns
1349 ///
1350 /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.
1351 fn int_argsort(tensor: IntTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
1352 let dtype = tensor.dtype();
1353 argsort::<B, Int>(tensor, dim, descending, dtype.into())
1354 }
1355
1356 /// Bitwise AND operation for Int Tensors
1357 fn bitwise_and(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
1358
1359 /// Bitwise AND operation for Int Tensors with a scalar
1360 fn bitwise_and_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
1361
1362 /// Bitwise OR operation for Int Tensors
1363 fn bitwise_or(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
1364
1365 /// Bitwise OR operation for Int Tensors with a scalar
1366 fn bitwise_or_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
1367
1368 /// Bitwise XOR operation for Int Tensors
1369 fn bitwise_xor(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
1370
1371 /// Bitwise XOR operation for Int Tensors with a scalar
1372 fn bitwise_xor_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
1373
1374 /// Bitwise NOT operation for Int Tensors
1375 fn bitwise_not(tensor: IntTensor<B>) -> IntTensor<B>;
1376
1377 /// Bitwise left shift operation for Int Tensors
1378 fn bitwise_left_shift(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
1379
1380 /// Bitwise left shift operation for Int Tensors with a scalar
1381 fn bitwise_left_shift_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
1382
1383 /// Bitwise right shift operation for Int Tensors
1384 fn bitwise_right_shift(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
1385
1386 /// Bitwise right shift operation for Int Tensors with a scalar
1387 fn bitwise_right_shift_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;
1388
1389 /// Converts a tensor to another integer data type.
1390 ///
1391 /// # Arguments
1392 ///
1393 /// * `tensor` - The tensor to convert.
1394 /// * `dtype` - The target data type.
1395 ///
1396 /// # Returns
1397 ///
1398 /// A tensor with the same values as `tensor` but in the target integer data type.
1399 fn int_cast(tensor: IntTensor<B>, dtype: IntDType) -> IntTensor<B>;
1400
1401 /// Unfold windows along a dimension.
1402 ///
1403 /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;
1404 /// where windows are advanced by `step` at each index.
1405 ///
1406 /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.
1407 ///
1408 /// # Arguments
1409 ///
1410 /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
1411 /// * `dim` - the selected dim.
1412 /// * `size` - the size of each unfolded window.
1413 /// * `step` - the step between each window.
1414 ///
1415 /// # Returns
1416 ///
1417 /// A tensor view with shape ``[pre=..., windows, size, post=...]``.
1418 fn int_unfold(tensor: IntTensor<B>, dim: usize, size: usize, step: usize) -> IntTensor<B>;
1419}