burn_tch/ops/int_tensor.rs

1use std::ops::Range;
2
3use burn_tensor::{
4    backend::Backend,
5    ops::{IntTensor, IntTensorOps},
6    Distribution, Shape, TensorData, TensorMetadata,
7};
8
9use crate::{element::TchElement, LibTorch, LibTorchDevice, QuantElement, TchShape, TchTensor};
10
11use super::TchOps;
12
13impl<E: TchElement, Q: QuantElement> IntTensorOps<Self> for LibTorch<E, Q> {
14    fn int_from_data(data: TensorData, device: &LibTorchDevice) -> TchTensor {
15        TchTensor::from_data::<i64>(data, (*device).into())
16    }
17
18    fn int_repeat_dim(tensor: TchTensor, dim: usize, times: usize) -> TchTensor {
19        TchOps::repeat_dim(tensor, dim, times)
20    }
21
22    async fn int_into_data(tensor: TchTensor) -> TensorData {
23        let shape = tensor.shape();
24        let tensor = Self::int_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
25        let values: Result<Vec<i64>, tch::TchError> = tensor.tensor.shallow_clone().try_into();
26        TensorData::new(values.unwrap(), shape)
27    }
28
29    fn int_to_device(tensor: TchTensor, device: &LibTorchDevice) -> TchTensor {
30        TchOps::to_device(tensor, device)
31    }
32
33    fn int_reshape(tensor: TchTensor, shape: Shape) -> TchTensor {
34        TchOps::reshape(tensor, shape)
35    }
36
37    fn int_device(tensor: &TchTensor) -> LibTorchDevice {
38        tensor.tensor.device().into()
39    }
40
41    fn int_empty(shape: Shape, device: &<LibTorch<E> as Backend>::Device) -> TchTensor {
42        let tensor = tch::Tensor::empty(
43            TchShape::from(shape).dims,
44            (tch::Kind::Int64, (*device).into()),
45        );
46
47        TchTensor::new(tensor)
48    }
49
50    fn int_slice(tensor: TchTensor, ranges: &[Range<usize>]) -> TchTensor {
51        TchOps::slice(tensor, ranges)
52    }
53
54    fn int_slice_assign(tensor: TchTensor, ranges: &[Range<usize>], value: TchTensor) -> TchTensor {
55        TchOps::slice_assign(tensor, ranges, value)
56    }
57
58    fn int_cat(tensors: Vec<TchTensor>, dim: usize) -> TchTensor {
59        TchOps::cat(tensors, dim)
60    }
61
62    fn int_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
63        TchOps::equal(lhs, rhs)
64    }
65
66    fn int_equal_elem(lhs: TchTensor, rhs: i64) -> TchTensor {
67        TchOps::equal_elem(lhs, rhs)
68    }
69
70    fn int_greater(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
71        TchOps::greater(lhs, rhs)
72    }
73
74    fn int_greater_elem(lhs: TchTensor, rhs: i64) -> TchTensor {
75        TchOps::greater_elem(lhs, rhs)
76    }
77
78    fn int_greater_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
79        TchOps::greater_equal(lhs, rhs)
80    }
81
82    fn int_greater_equal_elem(lhs: TchTensor, rhs: i64) -> TchTensor {
83        TchOps::greater_equal_elem(lhs, rhs)
84    }
85
86    fn int_lower(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
87        TchOps::lower(lhs, rhs)
88    }
89
90    fn int_lower_elem(lhs: TchTensor, rhs: i64) -> TchTensor {
91        TchOps::lower_elem(lhs, rhs)
92    }
93
94    fn int_lower_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
95        TchOps::lower_equal(lhs, rhs)
96    }
97
98    fn int_lower_equal_elem(lhs: TchTensor, rhs: i64) -> TchTensor {
99        TchOps::lower_equal_elem(lhs, rhs)
100    }
101
102    fn int_add(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
103        TchOps::add(lhs, rhs)
104    }
105
106    fn int_add_scalar(lhs: TchTensor, rhs: i64) -> TchTensor {
107        lhs.unary_ops(
108            |mut tensor| tensor.f_add_scalar_(rhs).unwrap(),
109            |tensor| tensor.f_add_scalar(rhs).unwrap(),
110        )
111    }
112
113    fn int_sub(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
114        TchOps::sub(lhs, rhs)
115    }
116
117    fn int_sub_scalar(lhs: TchTensor, rhs: i64) -> TchTensor {
118        lhs.unary_ops(
119            |mut tensor| tensor.f_sub_scalar_(rhs).unwrap(),
120            |tensor| tensor.f_sub_scalar(rhs).unwrap(),
121        )
122    }
123
124    fn int_mul(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
125        TchOps::mul(lhs, rhs)
126    }
127
128    fn int_mul_scalar(lhs: TchTensor, rhs: i64) -> TchTensor {
129        lhs.unary_ops(
130            |mut tensor| tensor.f_mul_scalar_(rhs).unwrap(),
131            |tensor| tensor.f_mul_scalar(rhs).unwrap(),
132        )
133    }
134
135    fn int_div(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
136        let copy = false;
137        let non_blocking = true;
138        let lhs: TchTensor =
139            TchTensor::new(lhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));
140        let rhs: TchTensor =
141            TchTensor::new(rhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));
142
143        let out = TchOps::div(lhs, rhs);
144
145        TchTensor::new(out.tensor.to_dtype(tch::Kind::Int64, non_blocking, copy))
146    }
147
148    fn int_div_scalar(lhs: TchTensor, rhs: i64) -> TchTensor {
149        let copy = false;
150        let non_blocking = true;
151        let lhs: TchTensor =
152            TchTensor::new(lhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));
153
154        let out: TchTensor = lhs.unary_ops(
155            |mut tensor| tensor.f_div_scalar_(rhs).unwrap(),
156            |tensor| tensor.f_div_scalar(rhs).unwrap(),
157        );
158
159        TchTensor::new(out.tensor.to_dtype(tch::Kind::Int64, non_blocking, copy))
160    }
161
162    fn int_remainder(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
163        let copy = false;
164        let non_blocking = true;
165        let lhs: TchTensor =
166            TchTensor::new(lhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));
167        let rhs: TchTensor =
168            TchTensor::new(rhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));
169
170        let out = TchOps::remainder(lhs, rhs);
171
172        TchTensor::new(out.tensor.to_dtype(tch::Kind::Int64, non_blocking, copy))
173    }
174
175    fn int_remainder_scalar(lhs: TchTensor, rhs: i64) -> TchTensor {
176        lhs.unary_ops(
177            |tensor| tensor.f_remainder(rhs).unwrap(),
178            |tensor| tensor.f_remainder(rhs).unwrap(),
179        )
180    }
181
182    fn int_neg(tensor: TchTensor) -> TchTensor {
183        Self::int_mul_scalar(tensor, -1)
184    }
185
186    fn int_zeros(shape: Shape, device: &<LibTorch<E> as Backend>::Device) -> TchTensor {
187        let shape = TchShape::from(shape);
188        let device: tch::Device = (*device).into();
189
190        TchTensor::new(tch::Tensor::zeros(shape.dims, (tch::Kind::Int64, device)))
191    }
192
193    fn int_ones(shape: Shape, device: &<LibTorch<E> as Backend>::Device) -> TchTensor {
194        let shape = TchShape::from(shape);
195        let device: tch::Device = (*device).into();
196
197        TchTensor::new(tch::Tensor::ones(shape.dims, (tch::Kind::Int64, device)))
198    }
199
200    fn int_full(
201        shape: Shape,
202        fill_value: i64,
203        device: &<LibTorch<E> as Backend>::Device,
204    ) -> TchTensor {
205        let shape = TchShape::from(shape);
206        let device: tch::Device = (*device).into();
207
208        TchTensor::new(tch::Tensor::full(
209            shape.dims,
210            fill_value,
211            (tch::Kind::Int64, device),
212        ))
213    }
214
215    fn int_sum(tensor: TchTensor) -> TchTensor {
216        TchOps::sum(tensor)
217    }
218
219    fn int_sum_dim(tensor: TchTensor, dim: usize) -> TchTensor {
220        TchOps::sum_dim(tensor, dim)
221    }
222
223    fn int_prod(tensor: TchTensor) -> TchTensor {
224        TchOps::prod(tensor)
225    }
226
227    fn int_prod_dim(tensor: TchTensor, dim: usize) -> TchTensor {
228        TchOps::prod_dim(tensor, dim)
229    }
230
231    fn int_mean(tensor: TchTensor) -> TchTensor {
232        let tensor: TchTensor =
233            TchTensor::new(tensor.tensor.to_dtype(tch::Kind::Float, true, false));
234        let output: TchTensor = TchTensor::new(TchOps::mean(tensor).tensor);
235
236        TchTensor::new(output.tensor.to_dtype(tch::Kind::Int64, true, false))
237    }
238
239    fn int_mean_dim(tensor: TchTensor, dim: usize) -> TchTensor {
240        let tensor: TchTensor =
241            TchTensor::new(tensor.tensor.to_dtype(tch::Kind::Float, true, false));
242
243        let output: TchTensor = TchTensor::new(TchOps::mean_dim(tensor, dim).tensor);
244
245        TchTensor::new(output.tensor.to_dtype(tch::Kind::Int64, true, false))
246    }
247
248    fn int_gather(dim: usize, tensor: TchTensor, indices: TchTensor) -> TchTensor {
249        TchOps::gather(dim, tensor, indices)
250    }
251
252    fn int_scatter(
253        dim: usize,
254        tensor: TchTensor,
255        indices: TchTensor,
256        value: TchTensor,
257    ) -> TchTensor {
258        TchOps::scatter(dim, tensor, indices, value)
259    }
260
261    fn int_select(tensor: TchTensor, dim: usize, indices: TchTensor) -> TchTensor {
262        TchOps::index_select_dim(tensor, dim, indices)
263    }
264
265    fn int_select_assign(
266        tensor: TchTensor,
267        dim: usize,
268        indices: TchTensor,
269        value: TchTensor,
270    ) -> TchTensor {
271        TchOps::select_assign(tensor, dim, indices, value)
272    }
273
274    fn int_mask_where(tensor: TchTensor, mask: TchTensor, source: TchTensor) -> TchTensor {
275        TchTensor::binary_ops_tensor(
276            tensor,
277            source,
278            |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),
279            |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),
280            |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),
281        )
282    }
283
284    fn int_mask_fill(tensor: TchTensor, mask: TchTensor, value: i64) -> TchTensor {
285        tensor.unary_ops(
286            |mut tensor| tensor.f_masked_fill_(&mask.tensor, value).unwrap(),
287            |tensor| tensor.f_masked_fill(&mask.tensor, value).unwrap(),
288        )
289    }
290
291    fn int_argmax(tensor: TchTensor, dim: usize) -> TchTensor {
292        TchOps::argmax(tensor, dim)
293    }
294
295    fn int_argmin(tensor: TchTensor, dim: usize) -> TchTensor {
296        TchOps::argmin(tensor, dim)
297    }
298
299    fn int_max_dim(tensor: TchTensor, dim: usize) -> TchTensor {
300        TchOps::max_dim(tensor, dim)
301    }
302
303    fn int_max_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) {
304        TchOps::max_dim_with_indices(tensor, dim)
305    }
306
307    fn int_min_dim(tensor: TchTensor, dim: usize) -> TchTensor {
308        TchOps::min_dim(tensor, dim)
309    }
310
311    fn int_min_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) {
312        TchOps::min_dim_with_indices(tensor, dim)
313    }
314
315    fn int_clamp_min(tensor: TchTensor, min: i64) -> TchTensor {
316        TchOps::clamp_min(tensor, min)
317    }
318
319    fn int_clamp_max(tensor: TchTensor, max: i64) -> TchTensor {
320        TchOps::clamp_max(tensor, max)
321    }
322
323    fn int_clamp(tensor: TchTensor, min: i64, max: i64) -> TchTensor {
324        TchOps::clamp(tensor, min, max)
325    }
326
327    fn int_abs(tensor: TchTensor) -> TchTensor {
328        tensor.unary_ops(|mut tensor| tensor.abs_(), |tensor| tensor.abs())
329    }
330
331    fn int_into_float(tensor: TchTensor) -> TchTensor {
332        let tensor = tensor.tensor.to_kind(E::KIND);
333        TchTensor::new(tensor)
334    }
335
336    fn int_swap_dims(tensor: IntTensor<Self>, dim1: usize, dim2: usize) -> IntTensor<Self> {
337        TchOps::swap_dims(tensor, dim1, dim2)
338    }
339
340    fn int_narrow(tensor: TchTensor, dim: usize, start: usize, length: usize) -> TchTensor {
341        TchOps::narrow(tensor, dim, start, length)
342    }
343
344    fn int_chunk(tensor: TchTensor, chunks: usize, dim: usize) -> Vec<TchTensor> {
345        TchOps::chunk(tensor, chunks, dim)
346    }
347
348    fn int_split(tensor: TchTensor, split_size: usize, dim: usize) -> Vec<TchTensor> {
349        TchOps::split(tensor, split_size, dim)
350    }
351
352    fn int_split_with_sizes(
353        tensor: TchTensor,
354        split_sizes: Vec<usize>,
355        dim: usize,
356    ) -> Vec<TchTensor> {
357        TchOps::split_with_sizes(tensor, split_sizes, dim)
358    }
359
360    fn int_random(shape: Shape, distribution: Distribution, device: &LibTorchDevice) -> TchTensor {
361        match distribution {
362            Distribution::Default => {
363                let mut tensor = TchTensor::empty::<i64>(shape, *device);
364                tensor
365                    .mut_ops(|tensor| tensor.uniform_(0.0, 255.0))
366                    .unwrap()
367            }
368            Distribution::Bernoulli(prob) => {
369                let mut tensor = TchTensor::empty::<i64>(shape, *device);
370                tensor
371                    .mut_ops(|tensor| tensor.f_bernoulli_float_(prob).unwrap())
372                    .unwrap()
373            }
374            Distribution::Uniform(from, to) => {
375                let mut tensor = TchTensor::empty::<i64>(shape, *device);
376                tensor.mut_ops(|tensor| tensor.uniform_(from, to)).unwrap()
377            }
378            Distribution::Normal(mean, std) => {
379                let mut tensor = TchTensor::empty::<i64>(shape, *device);
380                tensor.mut_ops(|tensor| tensor.normal_(mean, std)).unwrap()
381            }
382        }
383    }
384
385    fn int_arange(range: Range<i64>, device: &LibTorchDevice) -> TchTensor {
386        let device: tch::Device = (*device).into();
387        let mut tensor = tch::Tensor::arange(range.end - range.start, (tch::Kind::Int64, device));
388
389        if range.start != 0 {
390            tensor = tensor.f_add_scalar_(range.start).unwrap();
391        }
392
393        TchTensor::new(tensor)
394    }
395
396    fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
397        TchOps::permute(tensor, axes)
398    }
399
400    fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
401        TchOps::flip(tensor, axes)
402    }
403
404    fn int_sign(tensor: IntTensor<Self>) -> IntTensor<Self> {
405        TchOps::sign(tensor)
406    }
407
408    fn int_expand(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
409        TchOps::expand(tensor, shape)
410    }
411
412    fn int_sort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {
413        TchOps::sort(tensor, dim, descending)
414    }
415
416    fn int_argsort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {
417        TchOps::argsort(tensor, dim, descending)
418    }
419}