burn_tch/ops/
int_tensor.rs

1use std::ops::Range;
2
3use burn_tensor::{
4    Distribution, IntDType, Shape, TensorData, TensorMetadata,
5    backend::Backend,
6    ops::{FloatTensorOps, IntTensor, IntTensorOps},
7};
8
9use crate::{IntoKind, LibTorch, LibTorchDevice, TchShape, TchTensor, element::TchElement};
10
11use super::TchOps;
12
13impl<E: TchElement> IntTensorOps<Self> for LibTorch<E> {
14    fn int_from_data(data: TensorData, device: &LibTorchDevice) -> TchTensor {
15        match data.dtype {
16            burn_tensor::DType::I64 => TchTensor::from_data::<i64>(data, (*device).into()),
17            _ => unimplemented!("Unsupported dtype for `int_from_data`"),
18        }
19    }
20
21    fn int_repeat_dim(tensor: TchTensor, dim: usize, times: usize) -> TchTensor {
22        TchOps::repeat_dim(tensor, dim, times)
23    }
24
25    async fn int_into_data(tensor: TchTensor) -> TensorData {
26        let shape = tensor.shape();
27        let tensor = Self::int_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
28        let values: Result<Vec<i64>, tch::TchError> = tensor.tensor.shallow_clone().try_into();
29        TensorData::new(values.unwrap(), shape)
30    }
31
32    fn int_to_device(tensor: TchTensor, device: &LibTorchDevice) -> TchTensor {
33        TchOps::to_device(tensor, device)
34    }
35
36    fn int_reshape(tensor: TchTensor, shape: Shape) -> TchTensor {
37        TchOps::reshape(tensor, shape)
38    }
39
40    fn int_device(tensor: &TchTensor) -> LibTorchDevice {
41        tensor.tensor.device().into()
42    }
43
44    fn int_empty(
45        shape: Shape,
46        device: &<LibTorch<E> as Backend>::Device,
47        dtype: IntDType,
48    ) -> TchTensor {
49        let tensor = tch::Tensor::empty(
50            TchShape::from(shape).dims,
51            (dtype.into_kind(), (*device).into()),
52        );
53
54        TchTensor::new(tensor)
55    }
56
57    fn int_slice(tensor: TchTensor, slices: &[burn_tensor::Slice]) -> TchTensor {
58        TchOps::slice_with_steps(tensor, slices)
59    }
60
61    fn int_slice_assign(
62        tensor: TchTensor,
63        slices: &[burn_tensor::Slice],
64        value: TchTensor,
65    ) -> TchTensor {
66        TchOps::slice_assign(tensor, slices, value)
67    }
68
69    fn int_cat(tensors: Vec<TchTensor>, dim: usize) -> TchTensor {
70        TchOps::cat(tensors, dim)
71    }
72
73    fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
74        let lhs = Self::int_into_float(lhs);
75        let rhs = Self::int_into_float(rhs);
76        let out = lhs.tensor.f_matmul(&rhs.tensor).unwrap();
77        Self::float_into_int(TchTensor::new(out))
78    }
79
80    fn int_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
81        TchOps::equal(lhs, rhs)
82    }
83
84    fn int_equal_elem(lhs: TchTensor, rhs: i64) -> TchTensor {
85        TchOps::equal_elem(lhs, rhs)
86    }
87
88    fn int_greater(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
89        TchOps::greater(lhs, rhs)
90    }
91
92    fn int_greater_elem(lhs: TchTensor, rhs: i64) -> TchTensor {
93        TchOps::greater_elem(lhs, rhs)
94    }
95
96    fn int_greater_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
97        TchOps::greater_equal(lhs, rhs)
98    }
99
100    fn int_greater_equal_elem(lhs: TchTensor, rhs: i64) -> TchTensor {
101        TchOps::greater_equal_elem(lhs, rhs)
102    }
103
104    fn int_lower(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
105        TchOps::lower(lhs, rhs)
106    }
107
108    fn int_lower_elem(lhs: TchTensor, rhs: i64) -> TchTensor {
109        TchOps::lower_elem(lhs, rhs)
110    }
111
112    fn int_lower_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
113        TchOps::lower_equal(lhs, rhs)
114    }
115
116    fn int_lower_equal_elem(lhs: TchTensor, rhs: i64) -> TchTensor {
117        TchOps::lower_equal_elem(lhs, rhs)
118    }
119
120    fn int_add(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
121        TchOps::add(lhs, rhs)
122    }
123
124    fn int_add_scalar(lhs: TchTensor, rhs: i64) -> TchTensor {
125        lhs.unary_ops(
126            |mut tensor| tensor.f_add_scalar_(rhs).unwrap(),
127            |tensor| tensor.f_add_scalar(rhs).unwrap(),
128        )
129    }
130
131    fn int_sub(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
132        TchOps::sub(lhs, rhs)
133    }
134
135    fn int_sub_scalar(lhs: TchTensor, rhs: i64) -> TchTensor {
136        lhs.unary_ops(
137            |mut tensor| tensor.f_sub_scalar_(rhs).unwrap(),
138            |tensor| tensor.f_sub_scalar(rhs).unwrap(),
139        )
140    }
141
142    fn int_mul(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
143        TchOps::mul(lhs, rhs)
144    }
145
146    fn int_mul_scalar(lhs: TchTensor, rhs: i64) -> TchTensor {
147        lhs.unary_ops(
148            |mut tensor| tensor.f_mul_scalar_(rhs).unwrap(),
149            |tensor| tensor.f_mul_scalar(rhs).unwrap(),
150        )
151    }
152
153    fn int_div(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
154        let dtype = lhs.tensor.kind();
155        let copy = false;
156        let non_blocking = true;
157        let lhs: TchTensor =
158            TchTensor::new(lhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));
159        let rhs: TchTensor =
160            TchTensor::new(rhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));
161
162        let out = TchOps::div(lhs, rhs);
163
164        TchTensor::new(out.tensor.to_dtype(dtype, non_blocking, copy))
165    }
166
167    fn int_div_scalar(lhs: TchTensor, rhs: i64) -> TchTensor {
168        let dtype = lhs.tensor.kind();
169        let copy = false;
170        let non_blocking = true;
171        let lhs: TchTensor =
172            TchTensor::new(lhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));
173
174        let out: TchTensor = lhs.unary_ops(
175            |mut tensor| tensor.f_div_scalar_(rhs).unwrap(),
176            |tensor| tensor.f_div_scalar(rhs).unwrap(),
177        );
178
179        TchTensor::new(out.tensor.to_dtype(dtype, non_blocking, copy))
180    }
181
182    fn int_remainder(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
183        let dtype = lhs.tensor.kind();
184        let copy = false;
185        let non_blocking = true;
186        let lhs: TchTensor =
187            TchTensor::new(lhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));
188        let rhs: TchTensor =
189            TchTensor::new(rhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));
190
191        let out = TchOps::remainder(lhs, rhs);
192
193        TchTensor::new(out.tensor.to_dtype(dtype, non_blocking, copy))
194    }
195
196    fn int_remainder_scalar(lhs: TchTensor, rhs: i64) -> TchTensor {
197        lhs.unary_ops(
198            |tensor| tensor.f_remainder(rhs).unwrap(),
199            |tensor| tensor.f_remainder(rhs).unwrap(),
200        )
201    }
202
203    fn int_neg(tensor: TchTensor) -> TchTensor {
204        Self::int_mul_scalar(tensor, -1)
205    }
206
207    fn int_zeros(
208        shape: Shape,
209        device: &<LibTorch<E> as Backend>::Device,
210        dtype: IntDType,
211    ) -> TchTensor {
212        let shape = TchShape::from(shape);
213        let device: tch::Device = (*device).into();
214
215        TchTensor::new(tch::Tensor::zeros(shape.dims, (dtype.into_kind(), device)))
216    }
217
218    fn int_ones(
219        shape: Shape,
220        device: &<LibTorch<E> as Backend>::Device,
221        dtype: IntDType,
222    ) -> TchTensor {
223        let shape = TchShape::from(shape);
224        let device: tch::Device = (*device).into();
225
226        TchTensor::new(tch::Tensor::ones(shape.dims, (dtype.into_kind(), device)))
227    }
228
229    fn int_full(
230        shape: Shape,
231        fill_value: i64,
232        device: &<LibTorch<E> as Backend>::Device,
233        dtype: IntDType,
234    ) -> TchTensor {
235        let shape = TchShape::from(shape);
236        let device: tch::Device = (*device).into();
237
238        TchTensor::new(tch::Tensor::full(
239            shape.dims,
240            fill_value,
241            (dtype.into_kind(), device),
242        ))
243    }
244
245    fn int_sum(tensor: TchTensor) -> TchTensor {
246        TchOps::sum(tensor)
247    }
248
249    fn int_sum_dim(tensor: TchTensor, dim: usize) -> TchTensor {
250        TchOps::sum_dim(tensor, dim)
251    }
252
253    fn int_prod(tensor: TchTensor) -> TchTensor {
254        TchOps::prod(tensor)
255    }
256
257    fn int_prod_dim(tensor: TchTensor, dim: usize) -> TchTensor {
258        TchOps::prod_dim(tensor, dim)
259    }
260
261    fn int_mean(tensor: TchTensor) -> TchTensor {
262        let dtype = tensor.tensor.kind();
263        let tensor: TchTensor =
264            TchTensor::new(tensor.tensor.to_dtype(tch::Kind::Float, true, false));
265        let output: TchTensor = TchTensor::new(TchOps::mean(tensor).tensor);
266
267        TchTensor::new(output.tensor.to_dtype(dtype, true, false))
268    }
269
270    fn int_mean_dim(tensor: TchTensor, dim: usize) -> TchTensor {
271        let dtype = tensor.tensor.kind();
272        let tensor: TchTensor =
273            TchTensor::new(tensor.tensor.to_dtype(tch::Kind::Float, true, false));
274
275        let output: TchTensor = TchTensor::new(TchOps::mean_dim(tensor, dim).tensor);
276
277        TchTensor::new(output.tensor.to_dtype(dtype, true, false))
278    }
279
280    fn int_cumsum(tensor: TchTensor, dim: usize) -> TchTensor {
281        TchOps::cumsum(tensor, dim)
282    }
283
284    fn int_cumprod(tensor: TchTensor, dim: usize) -> TchTensor {
285        TchOps::cumprod(tensor, dim)
286    }
287
288    fn int_cummin(tensor: TchTensor, dim: usize) -> TchTensor {
289        TchOps::cummin(tensor, dim)
290    }
291
292    fn int_cummax(tensor: TchTensor, dim: usize) -> TchTensor {
293        TchOps::cummax(tensor, dim)
294    }
295
296    fn int_gather(dim: usize, tensor: TchTensor, indices: TchTensor) -> TchTensor {
297        TchOps::gather(dim, tensor, indices)
298    }
299
300    fn int_scatter(
301        dim: usize,
302        tensor: TchTensor,
303        indices: TchTensor,
304        value: TchTensor,
305    ) -> TchTensor {
306        TchOps::scatter(dim, tensor, indices, value)
307    }
308
309    fn int_select(tensor: TchTensor, dim: usize, indices: TchTensor) -> TchTensor {
310        TchOps::index_select_dim(tensor, dim, indices)
311    }
312
313    fn int_select_assign(
314        tensor: TchTensor,
315        dim: usize,
316        indices: TchTensor,
317        value: TchTensor,
318    ) -> TchTensor {
319        TchOps::select_assign(tensor, dim, indices, value)
320    }
321
322    fn int_mask_where(tensor: TchTensor, mask: TchTensor, source: TchTensor) -> TchTensor {
323        TchTensor::binary_ops_tensor(
324            tensor,
325            source,
326            |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),
327            |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),
328            |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),
329        )
330    }
331
332    fn int_mask_fill(tensor: TchTensor, mask: TchTensor, value: i64) -> TchTensor {
333        tensor.unary_ops(
334            |mut tensor| tensor.f_masked_fill_(&mask.tensor, value).unwrap(),
335            |tensor| tensor.f_masked_fill(&mask.tensor, value).unwrap(),
336        )
337    }
338
339    fn int_argmax(tensor: TchTensor, dim: usize) -> TchTensor {
340        TchOps::argmax(tensor, dim)
341    }
342
343    fn int_argmin(tensor: TchTensor, dim: usize) -> TchTensor {
344        TchOps::argmin(tensor, dim)
345    }
346
347    fn int_max_dim(tensor: TchTensor, dim: usize) -> TchTensor {
348        TchOps::max_dim(tensor, dim)
349    }
350
351    fn int_max_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) {
352        TchOps::max_dim_with_indices(tensor, dim)
353    }
354
355    fn int_min_dim(tensor: TchTensor, dim: usize) -> TchTensor {
356        TchOps::min_dim(tensor, dim)
357    }
358
359    fn int_min_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) {
360        TchOps::min_dim_with_indices(tensor, dim)
361    }
362
363    fn int_clamp_min(tensor: TchTensor, min: i64) -> TchTensor {
364        TchOps::clamp_min(tensor, min)
365    }
366
367    fn int_clamp_max(tensor: TchTensor, max: i64) -> TchTensor {
368        TchOps::clamp_max(tensor, max)
369    }
370
371    fn int_clamp(tensor: TchTensor, min: i64, max: i64) -> TchTensor {
372        TchOps::clamp(tensor, min, max)
373    }
374
375    fn int_abs(tensor: TchTensor) -> TchTensor {
376        tensor.unary_ops(|mut tensor| tensor.abs_(), |tensor| tensor.abs())
377    }
378
379    fn int_into_float(tensor: TchTensor) -> TchTensor {
380        let tensor = tensor.tensor.to_kind(E::KIND);
381        TchTensor::new(tensor)
382    }
383
384    fn int_swap_dims(tensor: IntTensor<Self>, dim1: usize, dim2: usize) -> IntTensor<Self> {
385        TchOps::swap_dims(tensor, dim1, dim2)
386    }
387
388    fn int_random(shape: Shape, distribution: Distribution, device: &LibTorchDevice) -> TchTensor {
389        match distribution {
390            Distribution::Default => TchTensor::new(tch::Tensor::randint_low(
391                0,
392                255,
393                shape.into_iter().map(|i| i as i64).collect::<Vec<_>>(),
394                (tch::Kind::Int64, (*device).into()),
395            )),
396            Distribution::Bernoulli(prob) => {
397                let mut tensor = TchTensor::empty::<i64>(shape, *device);
398                tensor
399                    .mut_ops(|tensor| tensor.f_bernoulli_float_(prob).unwrap())
400                    .unwrap()
401            }
402            Distribution::Uniform(from, to) => TchTensor::new(tch::Tensor::randint_low(
403                from as i64,
404                to as i64,
405                shape.into_iter().map(|i| i as i64).collect::<Vec<_>>(),
406                (tch::Kind::Int64, (*device).into()),
407            )),
408            Distribution::Normal(mean, std) => {
409                let mut tensor = TchTensor::empty::<i64>(shape, *device);
410                tensor.mut_ops(|tensor| tensor.normal_(mean, std)).unwrap()
411            }
412        }
413    }
414
415    fn int_arange(range: Range<i64>, device: &LibTorchDevice) -> TchTensor {
416        let device: tch::Device = (*device).into();
417        let mut tensor = tch::Tensor::arange(range.end - range.start, (tch::Kind::Int64, device));
418
419        if range.start != 0 {
420            tensor = tensor.f_add_scalar_(range.start).unwrap();
421        }
422
423        TchTensor::new(tensor)
424    }
425
426    fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
427        TchOps::permute(tensor, axes)
428    }
429
430    fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
431        TchOps::flip(tensor, axes)
432    }
433
434    fn int_sign(tensor: IntTensor<Self>) -> IntTensor<Self> {
435        TchOps::sign(tensor)
436    }
437
438    fn int_expand(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
439        TchOps::expand(tensor, shape)
440    }
441
442    fn int_sort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {
443        TchOps::sort(tensor, dim, descending)
444    }
445
446    fn int_argsort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {
447        TchOps::argsort(tensor, dim, descending)
448    }
449
450    fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
451        TchOps::bitwise_and(lhs, rhs)
452    }
453
454    fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
455        TchOps::bitwise_or(lhs, rhs)
456    }
457
458    fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
459        TchOps::bitwise_xor(lhs, rhs)
460    }
461
462    fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {
463        TchOps::bitwise_not(tensor)
464    }
465
466    fn bitwise_and_scalar(
467        lhs: IntTensor<Self>,
468        rhs: burn_tensor::ops::IntElem<Self>,
469    ) -> IntTensor<Self> {
470        TchOps::bitwise_and_scalar(lhs, rhs)
471    }
472
473    fn bitwise_or_scalar(
474        lhs: IntTensor<Self>,
475        rhs: burn_tensor::ops::IntElem<Self>,
476    ) -> IntTensor<Self> {
477        TchOps::bitwise_or_scalar(lhs, rhs)
478    }
479
480    fn bitwise_xor_scalar(
481        lhs: IntTensor<Self>,
482        rhs: burn_tensor::ops::IntElem<Self>,
483    ) -> IntTensor<Self> {
484        TchOps::bitwise_xor_scalar(lhs, rhs)
485    }
486
487    fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
488        TchOps::bitwise_left_shift(lhs, rhs)
489    }
490
491    fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
492        TchOps::bitwise_right_shift(lhs, rhs)
493    }
494
495    fn bitwise_left_shift_scalar(
496        lhs: IntTensor<Self>,
497        rhs: burn_tensor::ops::IntElem<Self>,
498    ) -> IntTensor<Self> {
499        TchOps::bitwise_left_shift_scalar(lhs, rhs)
500    }
501
502    fn bitwise_right_shift_scalar(
503        lhs: IntTensor<Self>,
504        rhs: burn_tensor::ops::IntElem<Self>,
505    ) -> IntTensor<Self> {
506        TchOps::bitwise_right_shift_scalar(lhs, rhs)
507    }
508
509    fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {
510        // NOTE: when dtypes of inputs to an arithmetic operation differ, tch handles type
511        // promotion based on a set of rules: https://pytorch.org/docs/stable/tensor_attributes.html#type-promotion-doc
512
513        // Type promotion is not automatic on all backends so this behavior might differ
514        let kind = dtype.into_kind();
515
516        if tensor.tensor.kind() == kind {
517            tensor
518        } else {
519            TchTensor::new(tensor.tensor.to_kind(kind))
520        }
521    }
522
523    fn int_unfold(
524        tensor: IntTensor<Self>,
525        dim: usize,
526        size: usize,
527        step: usize,
528    ) -> IntTensor<Self> {
529        TchOps::unfold(tensor, dim, size, step)
530    }
531}