1use std::ops::Range;
2
3use burn_backend::{
4 BoolDType, Distribution, ExecutionError, FloatDType, IntDType, Scalar, Shape, TensorData,
5 TensorMetadata,
6 ops::{FloatTensorOps, IntTensorOps},
7 tensor::IntTensor,
8};
9
10use crate::{IntoKind, LibTorch, LibTorchDevice, TchShape, TchTensor, element::TchElement};
11
12use super::TchOps;
13
14impl<E: TchElement> IntTensorOps<Self> for LibTorch<E> {
15 fn int_from_data(data: TensorData, device: &LibTorchDevice) -> TchTensor {
16 match data.dtype {
17 burn_backend::DType::I64 => TchTensor::from_data::<i64>(data, (*device).into()),
18 burn_backend::DType::I32 => TchTensor::from_data::<i32>(data, (*device).into()),
19 burn_backend::DType::I16 => TchTensor::from_data::<i16>(data, (*device).into()),
20 burn_backend::DType::I8 => TchTensor::from_data::<i8>(data, (*device).into()),
21 burn_backend::DType::U8 => TchTensor::from_data::<u8>(data, (*device).into()),
22 _ => unimplemented!("Unsupported dtype for `int_from_data`: {:?}", data.dtype),
23 }
24 }
25
26 fn int_repeat_dim(tensor: TchTensor, dim: usize, times: usize) -> TchTensor {
27 TchOps::repeat_dim(tensor, dim, times)
28 }
29
30 async fn int_into_data(tensor: TchTensor) -> Result<TensorData, ExecutionError> {
31 let shape = tensor.shape();
32 let tensor = Self::int_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
33 let values: Result<Vec<i64>, tch::TchError> = tensor.tensor.shallow_clone().try_into();
34 Ok(TensorData::new(values.unwrap(), shape))
35 }
36
37 fn int_to_device(tensor: TchTensor, device: &LibTorchDevice) -> TchTensor {
38 TchOps::to_device(tensor, device)
39 }
40
41 fn int_reshape(tensor: TchTensor, shape: Shape) -> TchTensor {
42 TchOps::reshape(tensor, shape)
43 }
44
45 fn int_device(tensor: &TchTensor) -> LibTorchDevice {
46 tensor.tensor.device().into()
47 }
48
49 fn int_empty(shape: Shape, device: &LibTorchDevice, dtype: IntDType) -> TchTensor {
50 let tensor = tch::Tensor::empty(
51 TchShape::from(shape).dims,
52 (dtype.into_kind(), (*device).into()),
53 );
54
55 TchTensor::new(tensor)
56 }
57
58 fn int_slice(tensor: TchTensor, slices: &[burn_backend::Slice]) -> TchTensor {
59 TchOps::slice_with_steps(tensor, slices)
60 }
61
62 fn int_slice_assign(
63 tensor: TchTensor,
64 slices: &[burn_backend::Slice],
65 value: TchTensor,
66 ) -> TchTensor {
67 TchOps::slice_assign(tensor, slices, value)
68 }
69
70 fn int_cat(tensors: Vec<TchTensor>, dim: usize) -> TchTensor {
71 TchOps::cat(tensors, dim)
72 }
73
74 fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
75 let int_dtype = lhs.dtype();
76 let lhs = Self::int_into_float(lhs, FloatDType::F32);
77 let rhs = Self::int_into_float(rhs, FloatDType::F32);
78 let out = lhs.tensor.f_matmul(&rhs.tensor).unwrap();
79 Self::float_into_int(TchTensor::new(out), int_dtype.into())
80 }
81
82 fn int_equal(lhs: TchTensor, rhs: TchTensor, _out_dtype: BoolDType) -> TchTensor {
83 TchOps::equal(lhs, rhs)
84 }
85
86 fn int_equal_elem(lhs: TchTensor, rhs: Scalar, _out_dtype: BoolDType) -> TchTensor {
87 TchOps::equal_elem(lhs, rhs.elem::<i64>())
88 }
89
90 fn int_greater(lhs: TchTensor, rhs: TchTensor, _out_dtype: BoolDType) -> TchTensor {
91 TchOps::greater(lhs, rhs)
92 }
93
94 fn int_greater_elem(lhs: TchTensor, rhs: Scalar, _out_dtype: BoolDType) -> TchTensor {
95 TchOps::greater_elem(lhs, rhs.elem::<i64>())
96 }
97
98 fn int_greater_equal(lhs: TchTensor, rhs: TchTensor, _out_dtype: BoolDType) -> TchTensor {
99 TchOps::greater_equal(lhs, rhs)
100 }
101
102 fn int_greater_equal_elem(lhs: TchTensor, rhs: Scalar, _out_dtype: BoolDType) -> TchTensor {
103 TchOps::greater_equal_elem(lhs, rhs.elem::<i64>())
104 }
105
106 fn int_lower(lhs: TchTensor, rhs: TchTensor, _out_dtype: BoolDType) -> TchTensor {
107 TchOps::lower(lhs, rhs)
108 }
109
110 fn int_lower_elem(lhs: TchTensor, rhs: Scalar, _out_dtype: BoolDType) -> TchTensor {
111 TchOps::lower_elem(lhs, rhs.elem::<i64>())
112 }
113
114 fn int_lower_equal(lhs: TchTensor, rhs: TchTensor, _out_dtype: BoolDType) -> TchTensor {
115 TchOps::lower_equal(lhs, rhs)
116 }
117
118 fn int_lower_equal_elem(lhs: TchTensor, rhs: Scalar, _out_dtype: BoolDType) -> TchTensor {
119 TchOps::lower_equal_elem(lhs, rhs.elem::<i64>())
120 }
121
122 fn int_add(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
123 TchOps::add(lhs, rhs)
124 }
125
126 fn int_add_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {
127 lhs.unary_ops(
128 |mut tensor| tensor.f_add_scalar_(rhs.elem::<i64>()).unwrap(),
129 |tensor| tensor.f_add_scalar(rhs.elem::<i64>()).unwrap(),
130 )
131 }
132
133 fn int_sub(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
134 TchOps::sub(lhs, rhs)
135 }
136
137 fn int_sub_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {
138 lhs.unary_ops(
139 |mut tensor| tensor.f_sub_scalar_(rhs.elem::<i64>()).unwrap(),
140 |tensor| tensor.f_sub_scalar(rhs.elem::<i64>()).unwrap(),
141 )
142 }
143
144 fn int_mul(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
145 TchOps::mul(lhs, rhs)
146 }
147
148 fn int_mul_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {
149 lhs.unary_ops(
150 |mut tensor| tensor.f_mul_scalar_(rhs.elem::<i64>()).unwrap(),
151 |tensor| tensor.f_mul_scalar(rhs.elem::<i64>()).unwrap(),
152 )
153 }
154
155 fn int_div(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
156 let dtype = lhs.tensor.kind();
157 let copy = false;
158 let non_blocking = true;
159 let lhs: TchTensor =
160 TchTensor::new(lhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));
161 let rhs: TchTensor =
162 TchTensor::new(rhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));
163
164 let out = TchOps::div(lhs, rhs);
165
166 TchTensor::new(out.tensor.to_dtype(dtype, non_blocking, copy))
167 }
168
169 fn int_div_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {
170 let dtype = lhs.tensor.kind();
171 let copy = false;
172 let non_blocking = true;
173 let lhs: TchTensor =
174 TchTensor::new(lhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));
175
176 let out: TchTensor = lhs.unary_ops(
177 |mut tensor| tensor.f_div_scalar_(rhs.elem::<i64>()).unwrap(),
178 |tensor| tensor.f_div_scalar(rhs.elem::<i64>()).unwrap(),
179 );
180
181 TchTensor::new(out.tensor.to_dtype(dtype, non_blocking, copy))
182 }
183
184 fn int_remainder(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
185 let dtype = lhs.tensor.kind();
186 let copy = false;
187 let non_blocking = true;
188 let lhs: TchTensor =
189 TchTensor::new(lhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));
190 let rhs: TchTensor =
191 TchTensor::new(rhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));
192
193 let out = TchOps::remainder(lhs, rhs);
194
195 TchTensor::new(out.tensor.to_dtype(dtype, non_blocking, copy))
196 }
197
198 fn int_remainder_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {
199 lhs.unary_ops(
200 |tensor| tensor.f_remainder(rhs.elem::<i64>()).unwrap(),
201 |tensor| tensor.f_remainder(rhs.elem::<i64>()).unwrap(),
202 )
203 }
204
205 fn int_zeros(shape: Shape, device: &LibTorchDevice, dtype: IntDType) -> TchTensor {
206 let shape = TchShape::from(shape);
207 let device: tch::Device = (*device).into();
208
209 TchTensor::new(tch::Tensor::zeros(shape.dims, (dtype.into_kind(), device)))
210 }
211
212 fn int_ones(shape: Shape, device: &LibTorchDevice, dtype: IntDType) -> TchTensor {
213 let shape = TchShape::from(shape);
214 let device: tch::Device = (*device).into();
215
216 TchTensor::new(tch::Tensor::ones(shape.dims, (dtype.into_kind(), device)))
217 }
218
219 fn int_full(
220 shape: Shape,
221 fill_value: Scalar,
222 device: &LibTorchDevice,
223 dtype: IntDType,
224 ) -> TchTensor {
225 let shape = TchShape::from(shape);
226 let device: tch::Device = (*device).into();
227
228 TchTensor::new(tch::Tensor::full(
229 shape.dims,
230 fill_value.elem::<i64>(),
231 (dtype.into_kind(), device),
232 ))
233 }
234
235 fn int_sum(tensor: TchTensor) -> TchTensor {
236 TchOps::sum(tensor)
237 }
238
239 fn int_sum_dim(tensor: TchTensor, dim: usize) -> TchTensor {
240 TchOps::sum_dim(tensor, dim)
241 }
242
243 fn int_prod(tensor: TchTensor) -> TchTensor {
244 TchOps::prod(tensor)
245 }
246
247 fn int_prod_dim(tensor: TchTensor, dim: usize) -> TchTensor {
248 TchOps::prod_dim(tensor, dim)
249 }
250
251 fn int_mean(tensor: TchTensor) -> TchTensor {
252 let dtype = tensor.tensor.kind();
253 let tensor: TchTensor =
254 TchTensor::new(tensor.tensor.to_dtype(tch::Kind::Float, true, false));
255 let output: TchTensor = TchTensor::new(TchOps::mean(tensor).tensor);
256
257 TchTensor::new(output.tensor.to_dtype(dtype, true, false))
258 }
259
260 fn int_mean_dim(tensor: TchTensor, dim: usize) -> TchTensor {
261 let dtype = tensor.tensor.kind();
262 let tensor: TchTensor =
263 TchTensor::new(tensor.tensor.to_dtype(tch::Kind::Float, true, false));
264
265 let output: TchTensor = TchTensor::new(TchOps::mean_dim(tensor, dim).tensor);
266
267 TchTensor::new(output.tensor.to_dtype(dtype, true, false))
268 }
269
270 fn int_cumsum(tensor: TchTensor, dim: usize) -> TchTensor {
271 TchOps::cumsum(tensor, dim)
272 }
273
274 fn int_cumprod(tensor: TchTensor, dim: usize) -> TchTensor {
275 TchOps::cumprod(tensor, dim)
276 }
277
278 fn int_cummin(tensor: TchTensor, dim: usize) -> TchTensor {
279 TchOps::cummin(tensor, dim)
280 }
281
282 fn int_cummax(tensor: TchTensor, dim: usize) -> TchTensor {
283 TchOps::cummax(tensor, dim)
284 }
285
286 fn int_gather(dim: usize, tensor: TchTensor, indices: TchTensor) -> TchTensor {
287 TchOps::gather(dim, tensor, indices)
288 }
289
290 fn int_scatter_add(
291 dim: usize,
292 tensor: TchTensor,
293 indices: TchTensor,
294 value: TchTensor,
295 ) -> TchTensor {
296 TchOps::scatter(dim, tensor, indices, value)
297 }
298
299 fn int_scatter_nd(
300 data: TchTensor,
301 indices: TchTensor,
302 values: TchTensor,
303 reduction: burn_backend::tensor::IndexingUpdateOp,
304 ) -> TchTensor {
305 TchOps::scatter_nd(data, indices, values, reduction)
306 }
307
308 fn int_gather_nd(data: TchTensor, indices: TchTensor) -> TchTensor {
309 TchOps::gather_nd(data, indices)
310 }
311
312 fn int_select(tensor: TchTensor, dim: usize, indices: TchTensor) -> TchTensor {
313 TchOps::index_select_dim(tensor, dim, indices)
314 }
315
316 fn int_select_add(
317 tensor: TchTensor,
318 dim: usize,
319 indices: TchTensor,
320 value: TchTensor,
321 ) -> TchTensor {
322 TchOps::select_assign(tensor, dim, indices, value)
323 }
324
325 fn int_mask_where(tensor: TchTensor, mask: TchTensor, source: TchTensor) -> TchTensor {
326 TchTensor::binary_ops_tensor(
327 tensor,
328 source,
329 |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),
330 |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),
331 |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),
332 )
333 }
334
335 fn int_mask_fill(tensor: TchTensor, mask: TchTensor, value: Scalar) -> TchTensor {
336 let value = value.elem::<i64>();
337 tensor.unary_ops(
338 |mut tensor| tensor.f_masked_fill_(&mask.tensor, value).unwrap(),
339 |tensor| tensor.f_masked_fill(&mask.tensor, value).unwrap(),
340 )
341 }
342
343 fn int_argmax(tensor: TchTensor, dim: usize) -> TchTensor {
344 TchOps::argmax(tensor, dim)
345 }
346
347 fn int_argtopk(_tensor: TchTensor, _dim: usize, _k: usize) -> TchTensor {
348 panic!("argtopk not implemented for torch")
349 }
350
351 fn int_topk(tensor: TchTensor, dim: usize, k: usize) -> TchTensor {
352 TchOps::topk(tensor, dim, k)
353 }
354
355 fn int_argmin(tensor: TchTensor, dim: usize) -> TchTensor {
356 TchOps::argmin(tensor, dim)
357 }
358
359 fn int_max_dim(tensor: TchTensor, dim: usize) -> TchTensor {
360 TchOps::max_dim(tensor, dim)
361 }
362
363 fn int_max_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) {
364 TchOps::max_dim_with_indices(tensor, dim)
365 }
366
367 fn int_min_dim(tensor: TchTensor, dim: usize) -> TchTensor {
368 TchOps::min_dim(tensor, dim)
369 }
370
371 fn int_min_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) {
372 TchOps::min_dim_with_indices(tensor, dim)
373 }
374
375 fn int_clamp_min(tensor: TchTensor, min: Scalar) -> TchTensor {
376 TchOps::clamp_min(tensor, min.elem::<i64>())
377 }
378
379 fn int_clamp_max(tensor: TchTensor, max: Scalar) -> TchTensor {
380 TchOps::clamp_max(tensor, max.elem::<i64>())
381 }
382
383 fn int_clamp(tensor: TchTensor, min: Scalar, max: Scalar) -> TchTensor {
384 TchOps::clamp(tensor, min.elem::<i64>(), max.elem::<i64>())
385 }
386
387 fn int_abs(tensor: TchTensor) -> TchTensor {
388 tensor.unary_ops(|mut tensor| tensor.abs_(), |tensor| tensor.abs())
389 }
390
391 fn int_into_float(tensor: TchTensor, out_dtype: FloatDType) -> TchTensor {
392 let tensor = tensor.tensor.to_kind(out_dtype.into_kind());
393 TchTensor::new(tensor)
394 }
395
396 fn int_swap_dims(tensor: IntTensor<Self>, dim1: usize, dim2: usize) -> IntTensor<Self> {
397 TchOps::swap_dims(tensor, dim1, dim2)
398 }
399
400 fn int_random(
401 shape: Shape,
402 distribution: Distribution,
403 device: &LibTorchDevice,
404 dtype: IntDType,
405 ) -> TchTensor {
406 match distribution {
407 Distribution::Default => TchTensor::new(tch::Tensor::randint_low(
408 0,
409 255,
410 shape.iter().map(|i| *i as i64).collect::<Vec<_>>(),
411 (dtype.into_kind(), (*device).into()),
412 )),
413 Distribution::Bernoulli(prob) => {
414 let mut tensor = TchTensor::empty(shape, *device, dtype.into());
415 tensor
416 .mut_ops(|tensor| tensor.f_bernoulli_float_(prob).unwrap())
417 .unwrap()
418 }
419 Distribution::Uniform(from, to) => TchTensor::new(tch::Tensor::randint_low(
420 from as i64,
421 to as i64,
422 shape.iter().map(|i| *i as i64).collect::<Vec<_>>(),
423 (dtype.into_kind(), (*device).into()),
424 )),
425 Distribution::Normal(mean, std) => {
426 let mut tensor = TchTensor::empty(shape, *device, dtype.into());
427 tensor.mut_ops(|tensor| tensor.normal_(mean, std)).unwrap()
428 }
429 }
430 }
431
432 fn int_arange(range: Range<i64>, device: &LibTorchDevice, dtype: IntDType) -> TchTensor {
433 let device: tch::Device = (*device).into();
434 let mut tensor = tch::Tensor::arange(range.end - range.start, (dtype.into_kind(), device));
435
436 if range.start != 0 {
437 tensor = tensor.f_add_scalar_(range.start).unwrap();
438 }
439
440 TchTensor::new(tensor)
441 }
442
443 fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
444 TchOps::permute(tensor, axes)
445 }
446
447 fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
448 TchOps::flip(tensor, axes)
449 }
450
451 fn int_sign(tensor: IntTensor<Self>) -> IntTensor<Self> {
452 TchOps::sign(tensor)
453 }
454
455 fn int_expand(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
456 TchOps::expand(tensor, shape)
457 }
458
459 fn int_sort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {
460 TchOps::sort(tensor, dim, descending)
461 }
462
463 fn int_argsort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {
464 TchOps::argsort(tensor, dim, descending)
465 }
466
467 fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
468 TchOps::bitwise_and(lhs, rhs)
469 }
470
471 fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
472 TchOps::bitwise_or(lhs, rhs)
473 }
474
475 fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
476 TchOps::bitwise_xor(lhs, rhs)
477 }
478
479 fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {
480 TchOps::bitwise_not(tensor)
481 }
482
483 fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
484 TchOps::bitwise_and_scalar(lhs, rhs.elem::<i64>())
485 }
486
487 fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
488 TchOps::bitwise_or_scalar(lhs, rhs.elem::<i64>())
489 }
490
491 fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
492 TchOps::bitwise_xor_scalar(lhs, rhs.elem::<i64>())
493 }
494
495 fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
496 TchOps::bitwise_left_shift(lhs, rhs)
497 }
498
499 fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
500 TchOps::bitwise_right_shift(lhs, rhs)
501 }
502
503 fn bitwise_left_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
504 TchOps::bitwise_left_shift_scalar(lhs, rhs.elem::<i64>())
505 }
506
507 fn bitwise_right_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
508 TchOps::bitwise_right_shift_scalar(lhs, rhs.elem::<i64>())
509 }
510
511 fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {
512 let kind = dtype.into_kind();
517
518 if tensor.tensor.kind() == kind {
519 tensor
520 } else {
521 TchTensor::new(tensor.tensor.to_kind(kind))
522 }
523 }
524
525 fn int_unfold(
526 tensor: IntTensor<Self>,
527 dim: usize,
528 size: usize,
529 step: usize,
530 ) -> IntTensor<Self> {
531 TchOps::unfold(tensor, dim, size, step)
532 }
533
534 fn int_powi(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
535 TchOps::pow(lhs, rhs)
536 }
537
538 fn int_powi_scalar_impl(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
539 lhs.unary_ops(
540 |mut tensor| tensor.f_pow_(rhs.elem::<i64>()).unwrap(),
541 |tensor| tensor.pow_tensor_scalar(rhs.elem::<i64>()),
542 )
543 }
544}