1use alloc::vec::Vec;
2use burn_backend::{
3 BoolDType, ExecutionError, FloatDType, IntDType, Scalar, Shape, Slice, TensorData,
4 ops::IntTensorOps,
5 tensor::{BoolTensor, FloatTensor, IntTensor},
6};
7
8use crate::backends::*;
9use crate::{Dispatch, DispatchDevice};
10
/// [`IntTensorOps`] for the [`Dispatch`] backend.
///
/// Every method forwards to the same-named operation on a concrete backend `B` through the
/// dispatch macros brought into scope by `use crate::backends::*` (`creation_op!`,
/// `unary_op!`, `binary_op!`, `multi_op!`, `vec_op!`, `to_device!`). Their expansion is not
/// visible in this file; presumably they select the concrete backend from the device / tensor
/// variant and re-wrap the result — check their definitions before relying on that.
///
/// Conventions visible in the invocations below:
/// - `creation_op!(Int, device, |device| ...)` builds a new tensor from a `DispatchDevice`.
/// - `unary_op!(x, int, ...)` and the `(x, int)` / `(mask, bool)` pairs of `binary_op!` /
///   `multi_op!` tag each input with a lowercase kind matching its tensor type in the
///   signature.
/// - The trailing `=> Int` / `=> Float` / `=> Bool` matches the method's return tensor type;
///   it is omitted only for `int_into_data`, which returns `TensorData`.
/// - `multi_op!` takes `inputs[...]` followed by either `=> Kind` (single output) or
///   `outputs[...]` (tuple of outputs).
impl IntTensorOps<Self> for Dispatch {
    // ---- Creation, data transfer and device placement ----

    fn int_empty(shape: Shape, device: &DispatchDevice, dtype: IntDType) -> IntTensor<Self> {
        creation_op!(Int, device, |device| B::int_empty(shape, device, dtype))
    }

    // The `.await` sits inside the macro closure, so `unary_op!` presumably expands the body
    // inline within this `async fn`. No `=> Kind` suffix: the result is `TensorData`.
    async fn int_into_data(tensor: IntTensor<Self>) -> Result<TensorData, ExecutionError> {
        unary_op!(tensor, int, |tensor| B::int_into_data(tensor).await)
    }

    fn int_from_data(data: TensorData, device: &DispatchDevice) -> IntTensor<Self> {
        creation_op!(Int, device, |device| B::int_from_data(data, device))
    }

    fn int_device(tensor: &IntTensor<Self>) -> DispatchDevice {
        tensor.device()
    }

    /// Moves `tensor` to `device`.
    ///
    /// The closure appears to be the cross-backend path (`B1` -> `B2`). It reads the data back
    /// with `burn_backend::read_sync` and re-creates the tensor with `B2::int_from_data`. The
    /// `int_to_device` identifier handed to the macro is presumably used for same-backend
    /// moves — confirm against `to_device!`.
    ///
    /// # Panics
    ///
    /// Panics with "Should read data" if the synchronous read on the cross-backend path
    /// returns an error.
    fn int_to_device(tensor: IntTensor<Self>, device: &DispatchDevice) -> IntTensor<Self> {
        to_device!(Int, int, tensor, device, int_to_device, |inner, device| {
            let data = burn_backend::read_sync(B1::int_into_data(inner)).expect("Should read data");
            B2::int_from_data(data, device)
        })
    }

    // ---- Shape and slicing ----

    fn int_reshape(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
        unary_op!(tensor, int, |tensor| B::int_reshape(tensor, shape) => Int)
    }

    fn int_slice(tensor: IntTensor<Self>, slices: &[Slice]) -> IntTensor<Self> {
        unary_op!(tensor, int, |tensor| B::int_slice(tensor, slices) => Int)
    }

    fn int_slice_assign(
        tensor: IntTensor<Self>,
        slices: &[Slice],
        value: IntTensor<Self>,
    ) -> IntTensor<Self> {
        binary_op!((tensor, int), (value, int), |tensor, value| B::int_slice_assign(tensor, slices, value) => Int)
    }

    // Int -> Float conversion: the only method here whose output kind is `Float`.
    fn int_into_float(tensor: IntTensor<Self>, out_dtype: FloatDType) -> FloatTensor<Self> {
        unary_op!(tensor, int, |tensor| B::int_into_float(tensor, out_dtype) => Float)
    }

    // ---- Masking, gather / scatter and select ----

    // Three tensor inputs (one of them a bool mask), so `multi_op!` is used instead of
    // `binary_op!`.
    fn int_mask_where(
        tensor: IntTensor<Self>,
        mask: BoolTensor<Self>,
        value: IntTensor<Self>,
    ) -> IntTensor<Self> {
        multi_op!(
            inputs[(tensor, int), (mask, bool), (value, int)], => Int,
            B::int_mask_where(tensor, mask, value)
        )
    }

    // `value` is a `Scalar`, so only `tensor` and `mask` go through the dispatch.
    fn int_mask_fill(
        tensor: IntTensor<Self>,
        mask: BoolTensor<Self>,
        value: Scalar,
    ) -> IntTensor<Self> {
        binary_op!((tensor, int), (mask, bool), |tensor, mask| B::int_mask_fill(tensor, mask, value) => Int)
    }

    fn int_gather(
        dim: usize,
        tensor: IntTensor<Self>,
        indices: IntTensor<Self>,
    ) -> IntTensor<Self> {
        binary_op!((tensor, int), (indices, int), |tensor, indices| B::int_gather(dim, tensor, indices) => Int)
    }

    fn int_scatter_add(
        dim: usize,
        tensor: IntTensor<Self>,
        indices: IntTensor<Self>,
        value: IntTensor<Self>,
    ) -> IntTensor<Self> {
        multi_op!(
            inputs[(tensor, int), (indices, int), (value, int)], => Int,
            B::int_scatter_add(dim, tensor, indices, value)
        )
    }

    // `reduction` is forwarded unchanged to the backend.
    fn int_scatter_nd(
        data: IntTensor<Self>,
        indices: IntTensor<Self>,
        values: IntTensor<Self>,
        reduction: burn_backend::tensor::IndexingUpdateOp,
    ) -> IntTensor<Self> {
        multi_op!(
            inputs[(data, int), (indices, int), (values, int)], => Int,
            B::int_scatter_nd(data, indices, values, reduction)
        )
    }

    fn int_gather_nd(data: IntTensor<Self>, indices: IntTensor<Self>) -> IntTensor<Self> {
        binary_op!((data, int), (indices, int), |data, indices| B::int_gather_nd(data, indices) => Int)
    }

    fn int_select(
        tensor: IntTensor<Self>,
        dim: usize,
        indices: IntTensor<Self>,
    ) -> IntTensor<Self> {
        binary_op!((tensor, int), (indices, int), |tensor, indices| B::int_select(tensor, dim, indices) => Int)
    }

    fn int_select_add(
        tensor: IntTensor<Self>,
        dim: usize,
        indices: IntTensor<Self>,
        value: IntTensor<Self>,
    ) -> IntTensor<Self> {
        multi_op!(
            inputs[(tensor, int), (indices, int), (value, int)], => Int,
            B::int_select_add(tensor, dim, indices, value)
        )
    }

    // ---- Comparisons ----
    // Each returns a `BoolTensor` with the requested `out_dtype`. The `_elem` variants compare
    // against a `Scalar`, so only `lhs` is dispatched.

    fn int_equal(
        lhs: IntTensor<Self>,
        rhs: IntTensor<Self>,
        out_dtype: BoolDType,
    ) -> BoolTensor<Self> {
        binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_equal(lhs, rhs, out_dtype) => Bool)
    }

    fn int_equal_elem(lhs: IntTensor<Self>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<Self> {
        unary_op!(lhs, int, |lhs| B::int_equal_elem(lhs, rhs, out_dtype) => Bool)
    }

    fn int_greater(
        lhs: IntTensor<Self>,
        rhs: IntTensor<Self>,
        out_dtype: BoolDType,
    ) -> BoolTensor<Self> {
        binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_greater(lhs, rhs, out_dtype) => Bool)
    }

    fn int_greater_elem(
        lhs: IntTensor<Self>,
        rhs: Scalar,
        out_dtype: BoolDType,
    ) -> BoolTensor<Self> {
        unary_op!(lhs, int, |lhs| B::int_greater_elem(lhs, rhs, out_dtype) => Bool)
    }

    fn int_greater_equal(
        lhs: IntTensor<Self>,
        rhs: IntTensor<Self>,
        out_dtype: BoolDType,
    ) -> BoolTensor<Self> {
        binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_greater_equal(lhs, rhs, out_dtype) => Bool)
    }

    fn int_greater_equal_elem(
        lhs: IntTensor<Self>,
        rhs: Scalar,
        out_dtype: BoolDType,
    ) -> BoolTensor<Self> {
        unary_op!(lhs, int, |lhs| B::int_greater_equal_elem(lhs, rhs, out_dtype) => Bool)
    }

    fn int_lower(
        lhs: IntTensor<Self>,
        rhs: IntTensor<Self>,
        out_dtype: BoolDType,
    ) -> BoolTensor<Self> {
        binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_lower(lhs, rhs, out_dtype) => Bool)
    }

    fn int_lower_elem(lhs: IntTensor<Self>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<Self> {
        unary_op!(lhs, int, |lhs| B::int_lower_elem(lhs, rhs, out_dtype) => Bool)
    }

    fn int_lower_equal(
        lhs: IntTensor<Self>,
        rhs: IntTensor<Self>,
        out_dtype: BoolDType,
    ) -> BoolTensor<Self> {
        binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_lower_equal(lhs, rhs, out_dtype) => Bool)
    }

    fn int_lower_equal_elem(
        lhs: IntTensor<Self>,
        rhs: Scalar,
        out_dtype: BoolDType,
    ) -> BoolTensor<Self> {
        unary_op!(lhs, int, |lhs| B::int_lower_equal_elem(lhs, rhs, out_dtype) => Bool)
    }

    // ---- Arithmetic ----
    // Tensor-tensor forms dispatch both operands; `_scalar` forms take a `Scalar` rhs.

    fn int_add(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_add(lhs, rhs) => Int)
    }

    fn int_add_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
        unary_op!(lhs, int, |lhs| B::int_add_scalar(lhs, rhs) => Int)
    }

    fn int_sub(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_sub(lhs, rhs) => Int)
    }

    fn int_sub_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
        unary_op!(lhs, int, |lhs| B::int_sub_scalar(lhs, rhs) => Int)
    }

    fn int_mul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_mul(lhs, rhs) => Int)
    }

    fn int_mul_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
        unary_op!(lhs, int, |lhs| B::int_mul_scalar(lhs, rhs) => Int)
    }

    fn int_div(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_div(lhs, rhs) => Int)
    }

    fn int_div_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
        unary_op!(lhs, int, |lhs| B::int_div_scalar(lhs, rhs) => Int)
    }

    fn int_remainder(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_remainder(lhs, rhs) => Int)
    }

    fn int_remainder_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
        unary_op!(lhs, int, |lhs| B::int_remainder_scalar(lhs, rhs) => Int)
    }

    fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_matmul(lhs, rhs) => Int)
    }

    // ---- Reductions (whole-tensor and per-`dim`) ----

    fn int_sum(tensor: IntTensor<Self>) -> IntTensor<Self> {
        unary_op!(tensor, int, |tensor| B::int_sum(tensor) => Int)
    }

    fn int_sum_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        unary_op!(tensor, int, |tensor| B::int_sum_dim(tensor, dim) => Int)
    }

    fn int_prod(tensor: IntTensor<Self>) -> IntTensor<Self> {
        unary_op!(tensor, int, |tensor| B::int_prod(tensor) => Int)
    }

    fn int_prod_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        unary_op!(tensor, int, |tensor| B::int_prod_dim(tensor, dim) => Int)
    }

    fn int_mean_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        unary_op!(tensor, int, |tensor| B::int_mean_dim(tensor, dim) => Int)
    }

    // ---- Cumulative ops along `dim` ----

    fn int_cumsum(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        unary_op!(tensor, int, |tensor| B::int_cumsum(tensor, dim) => Int)
    }

    fn int_cumprod(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        unary_op!(tensor, int, |tensor| B::int_cumprod(tensor, dim) => Int)
    }

    fn int_cummin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        unary_op!(tensor, int, |tensor| B::int_cummin(tensor, dim) => Int)
    }

    fn int_cummax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        unary_op!(tensor, int, |tensor| B::int_cummax(tensor, dim) => Int)
    }

    // ---- Arg / top-k selection ----

    fn int_argmax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        unary_op!(tensor, int, |tensor| B::int_argmax(tensor, dim) => Int)
    }

    fn int_argtopk(tensor: IntTensor<Self>, dim: usize, k: usize) -> IntTensor<Self> {
        unary_op!(tensor, int, |tensor| B::int_argtopk(tensor, dim, k) => Int)
    }

    fn int_topk(tensor: IntTensor<Self>, dim: usize, k: usize) -> IntTensor<Self> {
        unary_op!(tensor, int, |tensor| B::int_topk(tensor, dim, k) => Int)
    }

    fn int_argmin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        unary_op!(tensor, int, |tensor| B::int_argmin(tensor, dim) => Int)
    }

    fn int_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {
        unary_op!(tensor, int, |tensor| B::int_abs(tensor) => Int)
    }

    // ---- Axis manipulation ----

    fn int_swap_dims(tensor: IntTensor<Self>, dim1: usize, dim2: usize) -> IntTensor<Self> {
        unary_op!(tensor, int, |tensor| B::int_swap_dims(tensor, dim1, dim2) => Int)
    }

    fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
        unary_op!(tensor, int, |tensor| B::int_permute(tensor, axes) => Int)
    }

    fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
        unary_op!(tensor, int, |tensor| B::int_flip(tensor, axes) => Int)
    }

    // `distribution` is passed through unchanged to the selected backend's RNG.
    fn int_random(
        shape: Shape,
        distribution: burn_backend::Distribution,
        device: &DispatchDevice,
        dtype: IntDType,
    ) -> IntTensor<Self> {
        creation_op!(Int, device, |device| {
            B::int_random(shape, distribution, device, dtype)
        })
    }

    fn int_expand(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
        unary_op!(tensor, int, |tensor| B::int_expand(tensor, shape) => Int)
    }

    // ---- Bitwise ops (trait names carry no `int_` prefix) ----

    fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::bitwise_and(lhs, rhs) => Int)
    }

    fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
        unary_op!(lhs, int, |lhs| B::bitwise_and_scalar(lhs, rhs) => Int)
    }

    fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::bitwise_or(lhs, rhs) => Int)
    }

    fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
        unary_op!(lhs, int, |lhs| B::bitwise_or_scalar(lhs, rhs) => Int)
    }

    fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::bitwise_xor(lhs, rhs) => Int)
    }

    fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
        unary_op!(lhs, int, |lhs| B::bitwise_xor_scalar(lhs, rhs) => Int)
    }

    fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {
        unary_op!(tensor, int, |tensor| B::bitwise_not(tensor) => Int)
    }

    fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::bitwise_left_shift(lhs, rhs) => Int)
    }

    fn bitwise_left_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
        unary_op!(lhs, int, |lhs| B::bitwise_left_shift_scalar(lhs, rhs) => Int)
    }

    fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::bitwise_right_shift(lhs, rhs) => Int)
    }

    fn bitwise_right_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
        unary_op!(lhs, int, |lhs| B::bitwise_right_shift_scalar(lhs, rhs) => Int)
    }

    // ---- Dtype cast and layout ----

    fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {
        unary_op!(tensor, int, |tensor| B::int_cast(tensor, dtype) => Int)
    }

    fn int_unfold(
        tensor: IntTensor<Self>,
        dim: usize,
        size: usize,
        step: usize,
    ) -> IntTensor<Self> {
        unary_op!(tensor, int, |tensor| B::int_unfold(tensor, dim, size, step) => Int)
    }

    fn int_repeat_dim(tensor: IntTensor<Self>, dim: usize, times: usize) -> IntTensor<Self> {
        unary_op!(tensor, int, |tensor| B::int_repeat_dim(tensor, dim, times) => Int)
    }

    // `vec_op!` dispatches a whole `Vec` of tensors; presumably all of them must live on the
    // same backend — confirm against `vec_op!`.
    fn int_cat(tensors: Vec<IntTensor<Self>>, dim: usize) -> IntTensor<Self> {
        vec_op!(tensors, int, |tensors| B::int_cat(tensors, dim) => Int)
    }

    fn int_not_equal(
        lhs: IntTensor<Self>,
        rhs: IntTensor<Self>,
        out_dtype: BoolDType,
    ) -> BoolTensor<Self> {
        binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_not_equal(lhs, rhs, out_dtype) => Bool)
    }

    fn int_not_equal_elem(
        lhs: IntTensor<Self>,
        rhs: Scalar,
        out_dtype: BoolDType,
    ) -> BoolTensor<Self> {
        unary_op!(lhs, int, |lhs| B::int_not_equal_elem(lhs, rhs, out_dtype) => Bool)
    }

    fn int_powi(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_powi(lhs, rhs) => Int)
    }

    // Overrides the `_impl` hook rather than `int_powi_scalar` itself. Presumably the trait's
    // `int_powi_scalar` wrapper still runs on the `Dispatch` side before reaching here —
    // TODO confirm against `IntTensorOps`.
    fn int_powi_scalar_impl(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
        unary_op!(lhs, int, |lhs| B::int_powi_scalar_impl(lhs, rhs) => Int)
    }

    // ---- Clamping and negation ----

    fn int_clamp_min(tensor: IntTensor<Self>, min: Scalar) -> IntTensor<Self> {
        unary_op!(tensor, int, |tensor| B::int_clamp_min(tensor, min) => Int)
    }

    fn int_clamp_max(tensor: IntTensor<Self>, max: Scalar) -> IntTensor<Self> {
        unary_op!(tensor, int, |tensor| B::int_clamp_max(tensor, max) => Int)
    }

    fn int_clamp(tensor: IntTensor<Self>, min: Scalar, max: Scalar) -> IntTensor<Self> {
        unary_op!(tensor, int, |tensor| B::int_clamp(tensor, min, max) => Int)
    }

    fn int_neg(tensor: IntTensor<Self>) -> IntTensor<Self> {
        unary_op!(tensor, int, |tensor| B::int_neg(tensor) => Int)
    }

    // ---- Filled-tensor constructors ----

    fn int_zeros(shape: Shape, device: &DispatchDevice, dtype: IntDType) -> IntTensor<Self> {
        creation_op!(Int, device, |device| B::int_zeros(shape, device, dtype))
    }

    fn int_ones(shape: Shape, device: &DispatchDevice, dtype: IntDType) -> IntTensor<Self> {
        creation_op!(Int, device, |device| B::int_ones(shape, device, dtype))
    }

    fn int_full(
        shape: Shape,
        fill_value: Scalar,
        device: &DispatchDevice,
        dtype: IntDType,
    ) -> IntTensor<Self> {
        creation_op!(Int, device, |device| B::int_full(
            shape, fill_value, device, dtype
        ))
    }

    // ---- Mean / max / min reductions ----

    fn int_mean(tensor: IntTensor<Self>) -> IntTensor<Self> {
        unary_op!(tensor, int, |tensor| B::int_mean(tensor) => Int)
    }

    fn int_max(tensor: IntTensor<Self>) -> IntTensor<Self> {
        unary_op!(tensor, int, |tensor| B::int_max(tensor) => Int)
    }

    fn int_max_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        unary_op!(tensor, int, |tensor| B::int_max_dim(tensor, dim) => Int)
    }

    // Two outputs, so `multi_op!`'s `outputs[...]` form is used; both are `Int` tensors.
    fn int_max_dim_with_indices(
        tensor: IntTensor<Self>,
        dim: usize,
    ) -> (IntTensor<Self>, IntTensor<Self>) {
        multi_op!(
            inputs[(tensor, int)],
            outputs[(out, Int), (indices, Int)],
            B::int_max_dim_with_indices(tensor, dim)
        )
    }

    fn int_max_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {
        unary_op!(tensor, int, |tensor| B::int_max_abs(tensor) => Int)
    }

    fn int_max_abs_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        unary_op!(tensor, int, |tensor| B::int_max_abs_dim(tensor, dim) => Int)
    }

    fn int_min(tensor: IntTensor<Self>) -> IntTensor<Self> {
        unary_op!(tensor, int, |tensor| B::int_min(tensor) => Int)
    }

    fn int_min_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        unary_op!(tensor, int, |tensor| B::int_min_dim(tensor, dim) => Int)
    }

    // Two outputs, so `multi_op!`'s `outputs[...]` form is used; both are `Int` tensors.
    fn int_min_dim_with_indices(
        tensor: IntTensor<Self>,
        dim: usize,
    ) -> (IntTensor<Self>, IntTensor<Self>) {
        multi_op!(
            inputs[(tensor, int)],
            outputs[(out, Int), (indices, Int)],
            B::int_min_dim_with_indices(tensor, dim)
        )
    }

    fn int_transpose(tensor: IntTensor<Self>) -> IntTensor<Self> {
        unary_op!(tensor, int, |tensor| B::int_transpose(tensor) => Int)
    }

    // ---- Ranges ----

    fn int_arange_step(
        range: core::ops::Range<i64>,
        step: usize,
        device: &DispatchDevice,
        dtype: IntDType,
    ) -> IntTensor<Self> {
        creation_op!(Int, device, |device| B::int_arange_step(
            range, step, device, dtype
        ))
    }

    fn int_arange(
        range: core::ops::Range<i64>,
        device: &DispatchDevice,
        dtype: IntDType,
    ) -> IntTensor<Self> {
        creation_op!(Int, device, |device| B::int_arange(range, device, dtype))
    }

    // ---- Boolean reductions (return a `BoolTensor` of `out_dtype`) ----

    fn int_any(tensor: IntTensor<Self>, out_dtype: BoolDType) -> BoolTensor<Self> {
        unary_op!(tensor, int, |tensor| B::int_any(tensor, out_dtype) => Bool)
    }

    fn int_any_dim(tensor: IntTensor<Self>, dim: usize, out_dtype: BoolDType) -> BoolTensor<Self> {
        unary_op!(tensor, int, |tensor| B::int_any_dim(tensor, dim, out_dtype) => Bool)
    }

    fn int_all(tensor: IntTensor<Self>, out_dtype: BoolDType) -> BoolTensor<Self> {
        unary_op!(tensor, int, |tensor| B::int_all(tensor, out_dtype) => Bool)
    }

    fn int_all_dim(tensor: IntTensor<Self>, dim: usize, out_dtype: BoolDType) -> BoolTensor<Self> {
        unary_op!(tensor, int, |tensor| B::int_all_dim(tensor, dim, out_dtype) => Bool)
    }

    fn int_sign(tensor: IntTensor<Self>) -> IntTensor<Self> {
        unary_op!(tensor, int, |tensor| B::int_sign(tensor) => Int)
    }

    // ---- Sorting ----

    fn int_sort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {
        unary_op!(tensor, int, |tensor| B::int_sort(tensor, dim, descending) => Int)
    }

    // Two outputs, so `multi_op!`'s `outputs[...]` form is used; both are `Int` tensors.
    fn int_sort_with_indices(
        tensor: IntTensor<Self>,
        dim: usize,
        descending: bool,
    ) -> (IntTensor<Self>, IntTensor<Self>) {
        multi_op!(
            inputs[(tensor, int)],
            outputs[(out, Int), (indices, Int)],
            B::int_sort_with_indices(tensor, dim, descending)
        )
    }

    fn int_argsort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {
        unary_op!(tensor, int, |tensor| B::int_argsort(tensor, dim, descending) => Int)
    }
}