1use std::ops::Range;
2
3use burn_tensor::{
4 backend::Backend,
5 ops::{IntTensor, IntTensorOps},
6 Distribution, Shape, TensorData, TensorMetadata,
7};
8
9use crate::{element::TchElement, LibTorch, LibTorchDevice, QuantElement, TchShape, TchTensor};
10
11use super::TchOps;
12
13impl<E: TchElement, Q: QuantElement> IntTensorOps<Self> for LibTorch<E, Q> {
14 fn int_from_data(data: TensorData, device: &LibTorchDevice) -> TchTensor {
15 TchTensor::from_data::<i64>(data, (*device).into())
16 }
17
18 fn int_repeat_dim(tensor: TchTensor, dim: usize, times: usize) -> TchTensor {
19 TchOps::repeat_dim(tensor, dim, times)
20 }
21
22 async fn int_into_data(tensor: TchTensor) -> TensorData {
23 let shape = tensor.shape();
24 let tensor = Self::int_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
25 let values: Result<Vec<i64>, tch::TchError> = tensor.tensor.shallow_clone().try_into();
26 TensorData::new(values.unwrap(), shape)
27 }
28
29 fn int_to_device(tensor: TchTensor, device: &LibTorchDevice) -> TchTensor {
30 TchOps::to_device(tensor, device)
31 }
32
33 fn int_reshape(tensor: TchTensor, shape: Shape) -> TchTensor {
34 TchOps::reshape(tensor, shape)
35 }
36
37 fn int_device(tensor: &TchTensor) -> LibTorchDevice {
38 tensor.tensor.device().into()
39 }
40
41 fn int_empty(shape: Shape, device: &<LibTorch<E> as Backend>::Device) -> TchTensor {
42 let tensor = tch::Tensor::empty(
43 TchShape::from(shape).dims,
44 (tch::Kind::Int64, (*device).into()),
45 );
46
47 TchTensor::new(tensor)
48 }
49
50 fn int_slice(tensor: TchTensor, ranges: &[Range<usize>]) -> TchTensor {
51 TchOps::slice(tensor, ranges)
52 }
53
54 fn int_slice_assign(tensor: TchTensor, ranges: &[Range<usize>], value: TchTensor) -> TchTensor {
55 TchOps::slice_assign(tensor, ranges, value)
56 }
57
58 fn int_cat(tensors: Vec<TchTensor>, dim: usize) -> TchTensor {
59 TchOps::cat(tensors, dim)
60 }
61
62 fn int_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
63 TchOps::equal(lhs, rhs)
64 }
65
66 fn int_equal_elem(lhs: TchTensor, rhs: i64) -> TchTensor {
67 TchOps::equal_elem(lhs, rhs)
68 }
69
70 fn int_greater(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
71 TchOps::greater(lhs, rhs)
72 }
73
74 fn int_greater_elem(lhs: TchTensor, rhs: i64) -> TchTensor {
75 TchOps::greater_elem(lhs, rhs)
76 }
77
78 fn int_greater_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
79 TchOps::greater_equal(lhs, rhs)
80 }
81
82 fn int_greater_equal_elem(lhs: TchTensor, rhs: i64) -> TchTensor {
83 TchOps::greater_equal_elem(lhs, rhs)
84 }
85
86 fn int_lower(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
87 TchOps::lower(lhs, rhs)
88 }
89
90 fn int_lower_elem(lhs: TchTensor, rhs: i64) -> TchTensor {
91 TchOps::lower_elem(lhs, rhs)
92 }
93
94 fn int_lower_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
95 TchOps::lower_equal(lhs, rhs)
96 }
97
98 fn int_lower_equal_elem(lhs: TchTensor, rhs: i64) -> TchTensor {
99 TchOps::lower_equal_elem(lhs, rhs)
100 }
101
102 fn int_add(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
103 TchOps::add(lhs, rhs)
104 }
105
106 fn int_add_scalar(lhs: TchTensor, rhs: i64) -> TchTensor {
107 lhs.unary_ops(
108 |mut tensor| tensor.f_add_scalar_(rhs).unwrap(),
109 |tensor| tensor.f_add_scalar(rhs).unwrap(),
110 )
111 }
112
113 fn int_sub(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
114 TchOps::sub(lhs, rhs)
115 }
116
117 fn int_sub_scalar(lhs: TchTensor, rhs: i64) -> TchTensor {
118 lhs.unary_ops(
119 |mut tensor| tensor.f_sub_scalar_(rhs).unwrap(),
120 |tensor| tensor.f_sub_scalar(rhs).unwrap(),
121 )
122 }
123
124 fn int_mul(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
125 TchOps::mul(lhs, rhs)
126 }
127
128 fn int_mul_scalar(lhs: TchTensor, rhs: i64) -> TchTensor {
129 lhs.unary_ops(
130 |mut tensor| tensor.f_mul_scalar_(rhs).unwrap(),
131 |tensor| tensor.f_mul_scalar(rhs).unwrap(),
132 )
133 }
134
135 fn int_div(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
136 let copy = false;
137 let non_blocking = true;
138 let lhs: TchTensor =
139 TchTensor::new(lhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));
140 let rhs: TchTensor =
141 TchTensor::new(rhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));
142
143 let out = TchOps::div(lhs, rhs);
144
145 TchTensor::new(out.tensor.to_dtype(tch::Kind::Int64, non_blocking, copy))
146 }
147
148 fn int_div_scalar(lhs: TchTensor, rhs: i64) -> TchTensor {
149 let copy = false;
150 let non_blocking = true;
151 let lhs: TchTensor =
152 TchTensor::new(lhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));
153
154 let out: TchTensor = lhs.unary_ops(
155 |mut tensor| tensor.f_div_scalar_(rhs).unwrap(),
156 |tensor| tensor.f_div_scalar(rhs).unwrap(),
157 );
158
159 TchTensor::new(out.tensor.to_dtype(tch::Kind::Int64, non_blocking, copy))
160 }
161
162 fn int_remainder(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
163 let copy = false;
164 let non_blocking = true;
165 let lhs: TchTensor =
166 TchTensor::new(lhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));
167 let rhs: TchTensor =
168 TchTensor::new(rhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));
169
170 let out = TchOps::remainder(lhs, rhs);
171
172 TchTensor::new(out.tensor.to_dtype(tch::Kind::Int64, non_blocking, copy))
173 }
174
175 fn int_remainder_scalar(lhs: TchTensor, rhs: i64) -> TchTensor {
176 lhs.unary_ops(
177 |tensor| tensor.f_remainder(rhs).unwrap(),
178 |tensor| tensor.f_remainder(rhs).unwrap(),
179 )
180 }
181
182 fn int_neg(tensor: TchTensor) -> TchTensor {
183 Self::int_mul_scalar(tensor, -1)
184 }
185
186 fn int_zeros(shape: Shape, device: &<LibTorch<E> as Backend>::Device) -> TchTensor {
187 let shape = TchShape::from(shape);
188 let device: tch::Device = (*device).into();
189
190 TchTensor::new(tch::Tensor::zeros(shape.dims, (tch::Kind::Int64, device)))
191 }
192
193 fn int_ones(shape: Shape, device: &<LibTorch<E> as Backend>::Device) -> TchTensor {
194 let shape = TchShape::from(shape);
195 let device: tch::Device = (*device).into();
196
197 TchTensor::new(tch::Tensor::ones(shape.dims, (tch::Kind::Int64, device)))
198 }
199
200 fn int_full(
201 shape: Shape,
202 fill_value: i64,
203 device: &<LibTorch<E> as Backend>::Device,
204 ) -> TchTensor {
205 let shape = TchShape::from(shape);
206 let device: tch::Device = (*device).into();
207
208 TchTensor::new(tch::Tensor::full(
209 shape.dims,
210 fill_value,
211 (tch::Kind::Int64, device),
212 ))
213 }
214
215 fn int_sum(tensor: TchTensor) -> TchTensor {
216 TchOps::sum(tensor)
217 }
218
219 fn int_sum_dim(tensor: TchTensor, dim: usize) -> TchTensor {
220 TchOps::sum_dim(tensor, dim)
221 }
222
223 fn int_prod(tensor: TchTensor) -> TchTensor {
224 TchOps::prod(tensor)
225 }
226
227 fn int_prod_dim(tensor: TchTensor, dim: usize) -> TchTensor {
228 TchOps::prod_dim(tensor, dim)
229 }
230
231 fn int_mean(tensor: TchTensor) -> TchTensor {
232 let tensor: TchTensor =
233 TchTensor::new(tensor.tensor.to_dtype(tch::Kind::Float, true, false));
234 let output: TchTensor = TchTensor::new(TchOps::mean(tensor).tensor);
235
236 TchTensor::new(output.tensor.to_dtype(tch::Kind::Int64, true, false))
237 }
238
239 fn int_mean_dim(tensor: TchTensor, dim: usize) -> TchTensor {
240 let tensor: TchTensor =
241 TchTensor::new(tensor.tensor.to_dtype(tch::Kind::Float, true, false));
242
243 let output: TchTensor = TchTensor::new(TchOps::mean_dim(tensor, dim).tensor);
244
245 TchTensor::new(output.tensor.to_dtype(tch::Kind::Int64, true, false))
246 }
247
248 fn int_gather(dim: usize, tensor: TchTensor, indices: TchTensor) -> TchTensor {
249 TchOps::gather(dim, tensor, indices)
250 }
251
252 fn int_scatter(
253 dim: usize,
254 tensor: TchTensor,
255 indices: TchTensor,
256 value: TchTensor,
257 ) -> TchTensor {
258 TchOps::scatter(dim, tensor, indices, value)
259 }
260
261 fn int_select(tensor: TchTensor, dim: usize, indices: TchTensor) -> TchTensor {
262 TchOps::index_select_dim(tensor, dim, indices)
263 }
264
265 fn int_select_assign(
266 tensor: TchTensor,
267 dim: usize,
268 indices: TchTensor,
269 value: TchTensor,
270 ) -> TchTensor {
271 TchOps::select_assign(tensor, dim, indices, value)
272 }
273
274 fn int_mask_where(tensor: TchTensor, mask: TchTensor, source: TchTensor) -> TchTensor {
275 TchTensor::binary_ops_tensor(
276 tensor,
277 source,
278 |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),
279 |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),
280 |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),
281 )
282 }
283
284 fn int_mask_fill(tensor: TchTensor, mask: TchTensor, value: i64) -> TchTensor {
285 tensor.unary_ops(
286 |mut tensor| tensor.f_masked_fill_(&mask.tensor, value).unwrap(),
287 |tensor| tensor.f_masked_fill(&mask.tensor, value).unwrap(),
288 )
289 }
290
291 fn int_argmax(tensor: TchTensor, dim: usize) -> TchTensor {
292 TchOps::argmax(tensor, dim)
293 }
294
295 fn int_argmin(tensor: TchTensor, dim: usize) -> TchTensor {
296 TchOps::argmin(tensor, dim)
297 }
298
299 fn int_max_dim(tensor: TchTensor, dim: usize) -> TchTensor {
300 TchOps::max_dim(tensor, dim)
301 }
302
303 fn int_max_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) {
304 TchOps::max_dim_with_indices(tensor, dim)
305 }
306
307 fn int_min_dim(tensor: TchTensor, dim: usize) -> TchTensor {
308 TchOps::min_dim(tensor, dim)
309 }
310
311 fn int_min_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) {
312 TchOps::min_dim_with_indices(tensor, dim)
313 }
314
315 fn int_clamp_min(tensor: TchTensor, min: i64) -> TchTensor {
316 TchOps::clamp_min(tensor, min)
317 }
318
319 fn int_clamp_max(tensor: TchTensor, max: i64) -> TchTensor {
320 TchOps::clamp_max(tensor, max)
321 }
322
323 fn int_clamp(tensor: TchTensor, min: i64, max: i64) -> TchTensor {
324 TchOps::clamp(tensor, min, max)
325 }
326
327 fn int_abs(tensor: TchTensor) -> TchTensor {
328 tensor.unary_ops(|mut tensor| tensor.abs_(), |tensor| tensor.abs())
329 }
330
331 fn int_into_float(tensor: TchTensor) -> TchTensor {
332 let tensor = tensor.tensor.to_kind(E::KIND);
333 TchTensor::new(tensor)
334 }
335
336 fn int_swap_dims(tensor: IntTensor<Self>, dim1: usize, dim2: usize) -> IntTensor<Self> {
337 TchOps::swap_dims(tensor, dim1, dim2)
338 }
339
340 fn int_narrow(tensor: TchTensor, dim: usize, start: usize, length: usize) -> TchTensor {
341 TchOps::narrow(tensor, dim, start, length)
342 }
343
344 fn int_chunk(tensor: TchTensor, chunks: usize, dim: usize) -> Vec<TchTensor> {
345 TchOps::chunk(tensor, chunks, dim)
346 }
347
348 fn int_split(tensor: TchTensor, split_size: usize, dim: usize) -> Vec<TchTensor> {
349 TchOps::split(tensor, split_size, dim)
350 }
351
352 fn int_split_with_sizes(
353 tensor: TchTensor,
354 split_sizes: Vec<usize>,
355 dim: usize,
356 ) -> Vec<TchTensor> {
357 TchOps::split_with_sizes(tensor, split_sizes, dim)
358 }
359
360 fn int_random(shape: Shape, distribution: Distribution, device: &LibTorchDevice) -> TchTensor {
361 match distribution {
362 Distribution::Default => {
363 let mut tensor = TchTensor::empty::<i64>(shape, *device);
364 tensor
365 .mut_ops(|tensor| tensor.uniform_(0.0, 255.0))
366 .unwrap()
367 }
368 Distribution::Bernoulli(prob) => {
369 let mut tensor = TchTensor::empty::<i64>(shape, *device);
370 tensor
371 .mut_ops(|tensor| tensor.f_bernoulli_float_(prob).unwrap())
372 .unwrap()
373 }
374 Distribution::Uniform(from, to) => {
375 let mut tensor = TchTensor::empty::<i64>(shape, *device);
376 tensor.mut_ops(|tensor| tensor.uniform_(from, to)).unwrap()
377 }
378 Distribution::Normal(mean, std) => {
379 let mut tensor = TchTensor::empty::<i64>(shape, *device);
380 tensor.mut_ops(|tensor| tensor.normal_(mean, std)).unwrap()
381 }
382 }
383 }
384
385 fn int_arange(range: Range<i64>, device: &LibTorchDevice) -> TchTensor {
386 let device: tch::Device = (*device).into();
387 let mut tensor = tch::Tensor::arange(range.end - range.start, (tch::Kind::Int64, device));
388
389 if range.start != 0 {
390 tensor = tensor.f_add_scalar_(range.start).unwrap();
391 }
392
393 TchTensor::new(tensor)
394 }
395
396 fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
397 TchOps::permute(tensor, axes)
398 }
399
400 fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
401 TchOps::flip(tensor, axes)
402 }
403
404 fn int_sign(tensor: IntTensor<Self>) -> IntTensor<Self> {
405 TchOps::sign(tensor)
406 }
407
408 fn int_expand(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
409 TchOps::expand(tensor, shape)
410 }
411
412 fn int_sort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {
413 TchOps::sort(tensor, dim, descending)
414 }
415
416 fn int_argsort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {
417 TchOps::argsort(tensor, dim, descending)
418 }
419}