use std::ops::Range;

use burn_tensor::{
    Distribution, IntDType, Shape, TensorData, TensorMetadata,
    backend::Backend,
    ops::{FloatTensorOps, IntTensor, IntTensorOps},
};

use crate::{IntoKind, LibTorch, LibTorchDevice, TchShape, TchTensor, element::TchElement};

use super::TchOps;

impl<E: TchElement> IntTensorOps<Self> for LibTorch<E> {
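    // Only `I64` data is accepted here; any other dtype panics via `unimplemented!`.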
    fn int_from_data(data: TensorData, device: &LibTorchDevice) -> TchTensor {
        match data.dtype {
            burn_tensor::DType::I64 => TchTensor::from_data::<i64>(data, (*device).into()),
            _ => unimplemented!("Unsupported dtype for `int_from_data`"),
        }
    }

    fn int_repeat_dim(tensor: TchTensor, dim: usize, times: usize) -> TchTensor {
        TchOps::repeat_dim(tensor, dim, times)
    }

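    // Flattens the tensor to 1-D so its values can be copied out as a `Vec<i64>`,
    // then restores the original shape in the returned `TensorData`.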
    async fn int_into_data(tensor: TchTensor) -> TensorData {
        let shape = tensor.shape();
        let tensor = Self::int_reshape(tensor, Shape::new([shape.num_elements()]));
        let values: Result<Vec<i64>, tch::TchError> = tensor.tensor.shallow_clone().try_into();
        TensorData::new(values.unwrap(), shape)
    }

    fn int_to_device(tensor: TchTensor, device: &LibTorchDevice) -> TchTensor {
        TchOps::to_device(tensor, device)
    }

    fn int_reshape(tensor: TchTensor, shape: Shape) -> TchTensor {
        TchOps::reshape(tensor, shape)
    }

    fn int_device(tensor: &TchTensor) -> LibTorchDevice {
        tensor.tensor.device().into()
    }

    fn int_empty(
        shape: Shape,
        device: &<LibTorch<E> as Backend>::Device,
        dtype: IntDType,
    ) -> TchTensor {
        let tensor = tch::Tensor::empty(
            TchShape::from(shape).dims,
            (dtype.into_kind(), (*device).into()),
        );

        TchTensor::new(tensor)
    }

    fn int_slice(tensor: TchTensor, slices: &[burn_tensor::Slice]) -> TchTensor {
        TchOps::slice_with_steps(tensor, slices)
    }

    fn int_slice_assign(
        tensor: TchTensor,
        slices: &[burn_tensor::Slice],
        value: TchTensor,
    ) -> TchTensor {
        TchOps::slice_assign(tensor, slices, value)
    }

    fn int_cat(tensors: Vec<TchTensor>, dim: usize) -> TchTensor {
        TchOps::cat(tensors, dim)
    }

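    // LibTorch does not provide matmul kernels for integer kinds, so the
    // operands are promoted to the backend float element and the product is
    // cast back to int afterwards.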
    fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        let lhs = Self::int_into_float(lhs);
        let rhs = Self::int_into_float(rhs);
        let out = lhs.tensor.f_matmul(&rhs.tensor).unwrap();
        Self::float_into_int(TchTensor::new(out))
    }

    fn int_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
        TchOps::equal(lhs, rhs)
    }

    fn int_equal_elem(lhs: TchTensor, rhs: i64) -> TchTensor {
        TchOps::equal_elem(lhs, rhs)
    }

    fn int_greater(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
        TchOps::greater(lhs, rhs)
    }

    fn int_greater_elem(lhs: TchTensor, rhs: i64) -> TchTensor {
        TchOps::greater_elem(lhs, rhs)
    }

    fn int_greater_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
        TchOps::greater_equal(lhs, rhs)
    }

    fn int_greater_equal_elem(lhs: TchTensor, rhs: i64) -> TchTensor {
        TchOps::greater_equal_elem(lhs, rhs)
    }

    fn int_lower(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
        TchOps::lower(lhs, rhs)
    }

    fn int_lower_elem(lhs: TchTensor, rhs: i64) -> TchTensor {
        TchOps::lower_elem(lhs, rhs)
    }

    fn int_lower_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
        TchOps::lower_equal(lhs, rhs)
    }

    fn int_lower_equal_elem(lhs: TchTensor, rhs: i64) -> TchTensor {
        TchOps::lower_equal_elem(lhs, rhs)
    }

    fn int_add(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
        TchOps::add(lhs, rhs)
    }

    fn int_add_scalar(lhs: TchTensor, rhs: i64) -> TchTensor {
        lhs.unary_ops(
            |mut tensor| tensor.f_add_scalar_(rhs).unwrap(),
            |tensor| tensor.f_add_scalar(rhs).unwrap(),
        )
    }

    fn int_sub(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
        TchOps::sub(lhs, rhs)
    }

    fn int_sub_scalar(lhs: TchTensor, rhs: i64) -> TchTensor {
        lhs.unary_ops(
            |mut tensor| tensor.f_sub_scalar_(rhs).unwrap(),
            |tensor| tensor.f_sub_scalar(rhs).unwrap(),
        )
    }

    fn int_mul(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
        TchOps::mul(lhs, rhs)
    }

    fn int_mul_scalar(lhs: TchTensor, rhs: i64) -> TchTensor {
        lhs.unary_ops(
            |mut tensor| tensor.f_mul_scalar_(rhs).unwrap(),
            |tensor| tensor.f_mul_scalar(rhs).unwrap(),
        )
    }

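    // On integer inputs LibTorch's `div` performs true (floating-point)
    // division, so both operands are promoted to `Float` first and the
    // quotient is cast back to the original int kind, truncating toward zero.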
    fn int_div(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
        let dtype = lhs.tensor.kind();
        let copy = false;
        let non_blocking = true;
        let lhs: TchTensor =
            TchTensor::new(lhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));
        let rhs: TchTensor =
            TchTensor::new(rhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));

        let out = TchOps::div(lhs, rhs);

        TchTensor::new(out.tensor.to_dtype(dtype, non_blocking, copy))
    }

    fn int_div_scalar(lhs: TchTensor, rhs: i64) -> TchTensor {
        let dtype = lhs.tensor.kind();
        let copy = false;
        let non_blocking = true;
        let lhs: TchTensor =
            TchTensor::new(lhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));

        let out: TchTensor = lhs.unary_ops(
            |mut tensor| tensor.f_div_scalar_(rhs).unwrap(),
            |tensor| tensor.f_div_scalar(rhs).unwrap(),
        );

        TchTensor::new(out.tensor.to_dtype(dtype, non_blocking, copy))
    }

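    // Same promotion as `int_div`: compute the remainder in `Float`, then
    // cast the result back to the original int kind.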
    fn int_remainder(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
        let dtype = lhs.tensor.kind();
        let copy = false;
        let non_blocking = true;
        let lhs: TchTensor =
            TchTensor::new(lhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));
        let rhs: TchTensor =
            TchTensor::new(rhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));

        let out = TchOps::remainder(lhs, rhs);

        TchTensor::new(out.tensor.to_dtype(dtype, non_blocking, copy))
    }

    fn int_remainder_scalar(lhs: TchTensor, rhs: i64) -> TchTensor {
        lhs.unary_ops(
            |mut tensor| tensor.f_remainder_(rhs).unwrap(),
            |tensor| tensor.f_remainder(rhs).unwrap(),
        )
    }

    fn int_neg(tensor: TchTensor) -> TchTensor {
        Self::int_mul_scalar(tensor, -1)
    }

    fn int_zeros(
        shape: Shape,
        device: &<LibTorch<E> as Backend>::Device,
        dtype: IntDType,
    ) -> TchTensor {
        let shape = TchShape::from(shape);
        let device: tch::Device = (*device).into();

        TchTensor::new(tch::Tensor::zeros(shape.dims, (dtype.into_kind(), device)))
    }

    fn int_ones(
        shape: Shape,
        device: &<LibTorch<E> as Backend>::Device,
        dtype: IntDType,
    ) -> TchTensor {
        let shape = TchShape::from(shape);
        let device: tch::Device = (*device).into();

        TchTensor::new(tch::Tensor::ones(shape.dims, (dtype.into_kind(), device)))
    }

    fn int_full(
        shape: Shape,
        fill_value: i64,
        device: &<LibTorch<E> as Backend>::Device,
        dtype: IntDType,
    ) -> TchTensor {
        let shape = TchShape::from(shape);
        let device: tch::Device = (*device).into();

        TchTensor::new(tch::Tensor::full(
            shape.dims,
            fill_value,
            (dtype.into_kind(), device),
        ))
    }

    fn int_sum(tensor: TchTensor) -> TchTensor {
        TchOps::sum(tensor)
    }

    fn int_sum_dim(tensor: TchTensor, dim: usize) -> TchTensor {
        TchOps::sum_dim(tensor, dim)
    }

    fn int_prod(tensor: TchTensor) -> TchTensor {
        TchOps::prod(tensor)
    }

    fn int_prod_dim(tensor: TchTensor, dim: usize) -> TchTensor {
        TchOps::prod_dim(tensor, dim)
    }

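    // `mean` is only defined for floating-point kinds in LibTorch, so the
    // reduction runs in `Float` and the result is cast back to int.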
    fn int_mean(tensor: TchTensor) -> TchTensor {
        let dtype = tensor.tensor.kind();
        let tensor = TchTensor::new(tensor.tensor.to_dtype(tch::Kind::Float, true, false));
        let output = TchOps::mean(tensor);

        TchTensor::new(output.tensor.to_dtype(dtype, true, false))
    }

    fn int_mean_dim(tensor: TchTensor, dim: usize) -> TchTensor {
        let dtype = tensor.tensor.kind();
        let tensor = TchTensor::new(tensor.tensor.to_dtype(tch::Kind::Float, true, false));

        let output = TchOps::mean_dim(tensor, dim);

        TchTensor::new(output.tensor.to_dtype(dtype, true, false))
    }

    fn int_cumsum(tensor: TchTensor, dim: usize) -> TchTensor {
        TchOps::cumsum(tensor, dim)
    }

    fn int_cumprod(tensor: TchTensor, dim: usize) -> TchTensor {
        TchOps::cumprod(tensor, dim)
    }

    fn int_cummin(tensor: TchTensor, dim: usize) -> TchTensor {
        TchOps::cummin(tensor, dim)
    }

    fn int_cummax(tensor: TchTensor, dim: usize) -> TchTensor {
        TchOps::cummax(tensor, dim)
    }

    fn int_gather(dim: usize, tensor: TchTensor, indices: TchTensor) -> TchTensor {
        TchOps::gather(dim, tensor, indices)
    }

    fn int_scatter(
        dim: usize,
        tensor: TchTensor,
        indices: TchTensor,
        value: TchTensor,
    ) -> TchTensor {
        TchOps::scatter(dim, tensor, indices, value)
    }

    fn int_select(tensor: TchTensor, dim: usize, indices: TchTensor) -> TchTensor {
        TchOps::index_select_dim(tensor, dim, indices)
    }

    fn int_select_assign(
        tensor: TchTensor,
        dim: usize,
        indices: TchTensor,
        value: TchTensor,
    ) -> TchTensor {
        TchOps::select_assign(tensor, dim, indices, value)
    }

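    // `source.where_self(mask, tensor)` maps to `torch.where(mask, source, tensor)`:
    // the output takes `source` where the mask is true and `tensor` elsewhere.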
    fn int_mask_where(tensor: TchTensor, mask: TchTensor, source: TchTensor) -> TchTensor {
        TchTensor::binary_ops_tensor(
            tensor,
            source,
            |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),
            |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),
            |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),
        )
    }

    fn int_mask_fill(tensor: TchTensor, mask: TchTensor, value: i64) -> TchTensor {
        tensor.unary_ops(
            |mut tensor| tensor.f_masked_fill_(&mask.tensor, value).unwrap(),
            |tensor| tensor.f_masked_fill(&mask.tensor, value).unwrap(),
        )
    }

    fn int_argmax(tensor: TchTensor, dim: usize) -> TchTensor {
        TchOps::argmax(tensor, dim)
    }

    fn int_argmin(tensor: TchTensor, dim: usize) -> TchTensor {
        TchOps::argmin(tensor, dim)
    }

    fn int_max_dim(tensor: TchTensor, dim: usize) -> TchTensor {
        TchOps::max_dim(tensor, dim)
    }

    fn int_max_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) {
        TchOps::max_dim_with_indices(tensor, dim)
    }

    fn int_min_dim(tensor: TchTensor, dim: usize) -> TchTensor {
        TchOps::min_dim(tensor, dim)
    }

    fn int_min_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) {
        TchOps::min_dim_with_indices(tensor, dim)
    }

    fn int_clamp_min(tensor: TchTensor, min: i64) -> TchTensor {
        TchOps::clamp_min(tensor, min)
    }

    fn int_clamp_max(tensor: TchTensor, max: i64) -> TchTensor {
        TchOps::clamp_max(tensor, max)
    }

    fn int_clamp(tensor: TchTensor, min: i64, max: i64) -> TchTensor {
        TchOps::clamp(tensor, min, max)
    }

    fn int_abs(tensor: TchTensor) -> TchTensor {
        tensor.unary_ops(|mut tensor| tensor.abs_(), |tensor| tensor.abs())
    }

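    // Casts to `E::KIND`, the float element kind this backend was instantiated with.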
    fn int_into_float(tensor: TchTensor) -> TchTensor {
        let tensor = tensor.tensor.to_kind(E::KIND);
        TchTensor::new(tensor)
    }

    fn int_swap_dims(tensor: IntTensor<Self>, dim1: usize, dim2: usize) -> IntTensor<Self> {
        TchOps::swap_dims(tensor, dim1, dim2)
    }

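    // `Default` samples uniform integers in [0, 255) and `Uniform` truncates its
    // `f64` bounds before sampling; `Bernoulli` and `Normal` fill an empty `i64`
    // tensor in place.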
    fn int_random(shape: Shape, distribution: Distribution, device: &LibTorchDevice) -> TchTensor {
        match distribution {
            Distribution::Default => TchTensor::new(tch::Tensor::randint_low(
                0,
                255,
                shape.into_iter().map(|i| i as i64).collect::<Vec<_>>(),
                (tch::Kind::Int64, (*device).into()),
            )),
            Distribution::Bernoulli(prob) => {
                let mut tensor = TchTensor::empty::<i64>(shape, *device);
                tensor
                    .mut_ops(|tensor| tensor.f_bernoulli_float_(prob).unwrap())
                    .unwrap()
            }
            Distribution::Uniform(from, to) => TchTensor::new(tch::Tensor::randint_low(
                from as i64,
                to as i64,
                shape.into_iter().map(|i| i as i64).collect::<Vec<_>>(),
                (tch::Kind::Int64, (*device).into()),
            )),
            Distribution::Normal(mean, std) => {
                let mut tensor = TchTensor::empty::<i64>(shape, *device);
                tensor.mut_ops(|tensor| tensor.normal_(mean, std)).unwrap()
            }
        }
    }

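    // Builds a zero-based `arange` of the range's length, then shifts it by
    // `range.start`, skipping the add when the range already starts at zero.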
    fn int_arange(range: Range<i64>, device: &LibTorchDevice) -> TchTensor {
        let device: tch::Device = (*device).into();
        let mut tensor = tch::Tensor::arange(range.end - range.start, (tch::Kind::Int64, device));

        if range.start != 0 {
            tensor = tensor.f_add_scalar_(range.start).unwrap();
        }

        TchTensor::new(tensor)
    }

    fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
        TchOps::permute(tensor, axes)
    }

    fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
        TchOps::flip(tensor, axes)
    }

    fn int_sign(tensor: IntTensor<Self>) -> IntTensor<Self> {
        TchOps::sign(tensor)
    }

    fn int_expand(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
        TchOps::expand(tensor, shape)
    }

    fn int_sort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {
        TchOps::sort(tensor, dim, descending)
    }

    fn int_argsort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {
        TchOps::argsort(tensor, dim, descending)
    }

    fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        TchOps::bitwise_and(lhs, rhs)
    }

    fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        TchOps::bitwise_or(lhs, rhs)
    }

    fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        TchOps::bitwise_xor(lhs, rhs)
    }

    fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {
        TchOps::bitwise_not(tensor)
    }

    fn bitwise_and_scalar(
        lhs: IntTensor<Self>,
        rhs: burn_tensor::ops::IntElem<Self>,
    ) -> IntTensor<Self> {
        TchOps::bitwise_and_scalar(lhs, rhs)
    }

    fn bitwise_or_scalar(
        lhs: IntTensor<Self>,
        rhs: burn_tensor::ops::IntElem<Self>,
    ) -> IntTensor<Self> {
        TchOps::bitwise_or_scalar(lhs, rhs)
    }

    fn bitwise_xor_scalar(
        lhs: IntTensor<Self>,
        rhs: burn_tensor::ops::IntElem<Self>,
    ) -> IntTensor<Self> {
        TchOps::bitwise_xor_scalar(lhs, rhs)
    }

    fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        TchOps::bitwise_left_shift(lhs, rhs)
    }

    fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        TchOps::bitwise_right_shift(lhs, rhs)
    }

    fn bitwise_left_shift_scalar(
        lhs: IntTensor<Self>,
        rhs: burn_tensor::ops::IntElem<Self>,
    ) -> IntTensor<Self> {
        TchOps::bitwise_left_shift_scalar(lhs, rhs)
    }

    fn bitwise_right_shift_scalar(
        lhs: IntTensor<Self>,
        rhs: burn_tensor::ops::IntElem<Self>,
    ) -> IntTensor<Self> {
        TchOps::bitwise_right_shift_scalar(lhs, rhs)
    }

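    // No-op when the tensor already has the requested kind; otherwise `to_kind`
    // produces a converted copy.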
    fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {
        let kind = dtype.into_kind();

        if tensor.tensor.kind() == kind {
            tensor
        } else {
            TchTensor::new(tensor.tensor.to_kind(kind))
        }
    }

    fn int_unfold(
        tensor: IntTensor<Self>,
        dim: usize,
        size: usize,
        step: usize,
    ) -> IntTensor<Self> {
        TchOps::unfold(tensor, dim, size, step)
    }
}