Skip to main content

burn_tch/ops/
tensor.rs

1use super::TchOps;
2use crate::{IntoKind, LibTorch, LibTorchDevice, TchShape, TchTensor, element::TchElement};
3use burn_backend::backend::ExecutionError;
4use burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor};
5use burn_backend::{BoolDType, IntDType, Scalar, bf16, f16};
6use burn_backend::{
7    DType, Distribution, FloatDType, Shape, TensorData, TensorMetadata, ops::FloatTensorOps,
8};
9
10impl<E: TchElement> FloatTensorOps<Self> for LibTorch<E> {
11    fn float_from_data(data: TensorData, device: &LibTorchDevice) -> TchTensor {
12        match data.dtype {
13            DType::F64 => TchTensor::from_data::<f64>(data, (*device).into()),
14            DType::F32 => TchTensor::from_data::<f32>(data, (*device).into()),
15            DType::F16 => TchTensor::from_data::<f16>(data, (*device).into()),
16            DType::BF16 => TchTensor::from_data::<bf16>(data, (*device).into()),
17            _ => unimplemented!("Unsupported dtype for `float_from_data`"),
18        }
19    }
20
21    fn float_random(
22        shape: Shape,
23        distribution: Distribution,
24        device: &LibTorchDevice,
25        dtype: FloatDType,
26    ) -> TchTensor {
27        match distribution {
28            Distribution::Default => {
29                let mut tensor = TchTensor::empty(shape, *device, dtype.into());
30                tensor
31                    .mut_ops(|tensor| tensor.rand_like_out(tensor))
32                    .unwrap()
33            }
34            Distribution::Bernoulli(prob) => {
35                let mut tensor = TchTensor::empty(shape, *device, dtype.into());
36                tensor
37                    .mut_ops(|tensor| tensor.f_bernoulli_float_(prob).unwrap())
38                    .unwrap()
39            }
40            Distribution::Uniform(from, to) => {
41                let mut tensor = TchTensor::empty(shape, *device, dtype.into());
42                tensor.mut_ops(|tensor| tensor.uniform_(from, to)).unwrap()
43            }
44            Distribution::Normal(mean, std) => {
45                let mut tensor = TchTensor::empty(shape, *device, dtype.into());
46                tensor.mut_ops(|tensor| tensor.normal_(mean, std)).unwrap()
47            }
48        }
49    }
50
51    fn float_repeat_dim(tensor: TchTensor, dim: usize, times: usize) -> TchTensor {
52        TchOps::repeat_dim(tensor, dim, times)
53    }
54
55    fn float_zeros(shape: Shape, device: &LibTorchDevice, dtype: FloatDType) -> TchTensor {
56        let shape = TchShape::from(shape);
57        let device: tch::Device = (*device).into();
58
59        TchTensor::new(tch::Tensor::zeros(shape.dims, (dtype.into_kind(), device)))
60    }
61
62    fn float_ones(shape: Shape, device: &LibTorchDevice, dtype: FloatDType) -> TchTensor {
63        let shape = TchShape::from(shape);
64        let device: tch::Device = (*device).into();
65
66        TchTensor::new(tch::Tensor::ones(shape.dims, (dtype.into_kind(), device)))
67    }
68
69    async fn float_into_data(tensor: TchTensor) -> Result<TensorData, ExecutionError> {
70        let shape = tensor.shape();
71        let tensor = Self::float_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
72        Ok(match tensor.tensor.kind() {
73            tch::Kind::Half => {
74                let values = Vec::<f16>::try_from(&tensor).unwrap();
75                TensorData::new(values, shape)
76            }
77            tch::Kind::Float => {
78                let values = Vec::<f32>::try_from(&tensor).unwrap();
79                TensorData::new(values, shape)
80            }
81            tch::Kind::Double => {
82                let values = Vec::<f64>::try_from(&tensor).unwrap();
83                TensorData::new(values, shape)
84            }
85            tch::Kind::BFloat16 => {
86                let values = Vec::<bf16>::try_from(&tensor).unwrap();
87                TensorData::new(values, shape)
88            }
89            _ => panic!("Not a valid float kind"),
90        })
91    }
92
93    fn float_device(tensor: &TchTensor) -> LibTorchDevice {
94        tensor.tensor.device().into()
95    }
96
97    fn float_to_device(tensor: TchTensor, device: &LibTorchDevice) -> TchTensor {
98        TchOps::to_device(tensor, device)
99    }
100
101    fn float_empty(shape: Shape, device: &LibTorchDevice, dtype: FloatDType) -> TchTensor {
102        let tensor = tch::Tensor::empty(
103            TchShape::from(shape).dims,
104            (dtype.into_kind(), (*device).into()),
105        );
106
107        TchTensor::new(tensor)
108    }
109
110    fn float_add(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
111        TchOps::add(lhs, rhs)
112    }
113
114    fn float_add_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {
115        let rhs: f64 = rhs.elem();
116
117        lhs.unary_ops(
118            |mut tensor| tensor.f_add_scalar_(rhs).unwrap(),
119            |tensor| tensor.f_add_scalar(rhs).unwrap(),
120        )
121    }
122
123    fn float_sub(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
124        TchOps::sub(lhs, rhs)
125    }
126
127    fn float_sub_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {
128        let rhs: f64 = rhs.elem();
129
130        lhs.unary_ops(
131            |mut tensor| tensor.f_sub_scalar_(rhs).unwrap(),
132            |tensor| tensor.f_sub_scalar(rhs).unwrap(),
133        )
134    }
135
136    fn float_mul(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
137        TchOps::mul(lhs, rhs)
138    }
139
140    fn float_mul_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {
141        let rhs: f64 = rhs.elem();
142
143        lhs.unary_ops(
144            |mut tensor| tensor.f_mul_scalar_(rhs).unwrap(),
145            |tensor| tensor.f_mul_scalar(rhs).unwrap(),
146        )
147    }
148
149    fn float_div(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
150        TchOps::div(lhs, rhs)
151    }
152
153    fn float_div_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {
154        let rhs: f64 = rhs.elem();
155
156        lhs.unary_ops(
157            |mut tensor| tensor.f_div_scalar_(rhs).unwrap(),
158            |tensor| tensor.f_div_scalar(rhs).unwrap(),
159        )
160    }
161
162    fn float_remainder(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
163        TchOps::remainder(lhs, rhs)
164    }
165
166    fn float_remainder_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {
167        let rhs: f64 = rhs.elem();
168
169        lhs.unary_ops(
170            |tensor| tensor.f_remainder(rhs).unwrap(),
171            |tensor| tensor.f_remainder(rhs).unwrap(),
172        )
173    }
174
175    fn float_matmul(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
176        let tensor = lhs.tensor.matmul(&rhs.tensor);
177        TchTensor::new(tensor)
178    }
179
180    fn float_cross(lhs: TchTensor, rhs: TchTensor, dim: usize) -> TchTensor {
181        let tensor = lhs.tensor.cross(&rhs.tensor, dim as i64);
182        TchTensor::new(tensor)
183    }
184
185    fn float_recip(tensor: TchTensor) -> TchTensor {
186        TchTensor::new(tensor.tensor.reciprocal())
187    }
188
189    fn float_swap_dims(tensor: TchTensor, dim1: usize, dim2: usize) -> TchTensor {
190        TchOps::swap_dims(tensor, dim1, dim2)
191    }
192
193    fn float_reshape(tensor: TchTensor, shape: Shape) -> TchTensor {
194        TchOps::reshape(tensor, shape)
195    }
196
197    fn float_gather(dim: usize, tensor: TchTensor, indices: TchTensor) -> TchTensor {
198        TchOps::gather(dim, tensor, indices)
199    }
200
201    fn float_scatter_add(
202        dim: usize,
203        tensor: TchTensor,
204        indices: TchTensor,
205        value: TchTensor,
206    ) -> TchTensor {
207        TchOps::scatter(dim, tensor, indices, value)
208    }
209
210    fn float_scatter_nd(
211        data: TchTensor,
212        indices: TchTensor,
213        values: TchTensor,
214        reduction: burn_backend::tensor::IndexingUpdateOp,
215    ) -> TchTensor {
216        TchOps::scatter_nd(data, indices, values, reduction)
217    }
218
219    fn float_gather_nd(data: TchTensor, indices: TchTensor) -> TchTensor {
220        TchOps::gather_nd(data, indices)
221    }
222
223    fn float_select(tensor: TchTensor, dim: usize, indices: TchTensor) -> TchTensor {
224        TchOps::index_select_dim(tensor, dim, indices)
225    }
226
227    fn float_select_add(
228        tensor: TchTensor,
229        dim: usize,
230        indices: TchTensor,
231        value: TchTensor,
232    ) -> TchTensor {
233        TchOps::select_assign(tensor, dim, indices, value)
234    }
235
236    fn float_slice(tensor: TchTensor, slices: &[burn_backend::Slice]) -> TchTensor {
237        TchOps::slice_with_steps(tensor, slices)
238    }
239
240    fn float_slice_assign(
241        tensor: TchTensor,
242        slices: &[burn_backend::Slice],
243        value: TchTensor,
244    ) -> TchTensor {
245        TchOps::slice_assign(tensor, slices, value)
246    }
247
248    fn float_mask_where(tensor: TchTensor, mask: TchTensor, value: TchTensor) -> TchTensor {
249        let output = value.tensor.where_self(&mask.tensor, &tensor.tensor);
250
251        TchTensor::new(output)
252    }
253
254    fn float_mask_fill(tensor: TchTensor, mask: TchTensor, value: Scalar) -> TchTensor {
255        let value: f64 = value.elem();
256
257        tensor.unary_ops(
258            |mut tensor| tensor.f_masked_fill_(&mask.tensor, value).unwrap(),
259            |tensor| tensor.f_masked_fill(&mask.tensor, value).unwrap(),
260        )
261    }
262
263    fn float_equal(lhs: TchTensor, rhs: TchTensor, _out_dtype: BoolDType) -> TchTensor {
264        TchOps::equal(lhs, rhs)
265    }
266
267    fn float_equal_elem(lhs: TchTensor, rhs: Scalar, _out_dtype: BoolDType) -> TchTensor {
268        TchOps::equal_elem(lhs, rhs.elem::<f64>())
269    }
270
271    fn float_greater(lhs: TchTensor, rhs: TchTensor, _out_dtype: BoolDType) -> TchTensor {
272        TchOps::greater(lhs, rhs)
273    }
274
275    fn float_greater_elem(lhs: TchTensor, rhs: Scalar, _out_dtype: BoolDType) -> TchTensor {
276        TchOps::greater_elem(lhs, rhs.elem::<f64>())
277    }
278
279    fn float_greater_equal(lhs: TchTensor, rhs: TchTensor, _out_dtype: BoolDType) -> TchTensor {
280        TchOps::greater_equal(lhs, rhs)
281    }
282
283    fn float_greater_equal_elem(lhs: TchTensor, rhs: Scalar, _out_dtype: BoolDType) -> TchTensor {
284        TchOps::greater_equal_elem(lhs, rhs.elem::<f64>())
285    }
286
287    fn float_lower(lhs: TchTensor, rhs: TchTensor, _out_dtype: BoolDType) -> TchTensor {
288        TchOps::lower(lhs, rhs)
289    }
290
291    fn float_lower_elem(lhs: TchTensor, rhs: Scalar, _out_dtype: BoolDType) -> TchTensor {
292        TchOps::lower_elem(lhs, rhs.elem::<f64>())
293    }
294
295    fn float_lower_equal(lhs: TchTensor, rhs: TchTensor, _out_dtype: BoolDType) -> TchTensor {
296        TchOps::lower_equal(lhs, rhs)
297    }
298
299    fn float_lower_equal_elem(lhs: TchTensor, rhs: Scalar, _out_dtype: BoolDType) -> TchTensor {
300        TchOps::lower_equal_elem(lhs, rhs.elem::<f64>())
301    }
302
303    fn float_mean(tensor: TchTensor) -> TchTensor {
304        TchOps::mean(tensor)
305    }
306
307    fn float_sum(tensor: TchTensor) -> TchTensor {
308        TchOps::sum(tensor)
309    }
310
311    fn float_sum_dim(tensor: TchTensor, dim: usize) -> TchTensor {
312        TchOps::sum_dim(tensor, dim)
313    }
314
315    fn float_mean_dim(tensor: TchTensor, dim: usize) -> TchTensor {
316        TchOps::mean_dim(tensor, dim)
317    }
318
319    fn float_cumsum(tensor: TchTensor, dim: usize) -> TchTensor {
320        TchOps::cumsum(tensor, dim)
321    }
322
323    fn float_cumprod(tensor: TchTensor, dim: usize) -> TchTensor {
324        TchOps::cumprod(tensor, dim)
325    }
326
327    fn float_cummin(tensor: TchTensor, dim: usize) -> TchTensor {
328        TchOps::cummin(tensor, dim)
329    }
330
331    fn float_cummax(tensor: TchTensor, dim: usize) -> TchTensor {
332        TchOps::cummax(tensor, dim)
333    }
334
335    fn float_prod(tensor: TchTensor) -> TchTensor {
336        TchOps::prod(tensor)
337    }
338
339    fn float_prod_dim(tensor: TchTensor, dim: usize) -> TchTensor {
340        TchOps::prod_dim(tensor, dim)
341    }
342
343    fn float_argmax(tensor: TchTensor, dim: usize, _indices_dtype: IntDType) -> TchTensor {
344        TchOps::argmax(tensor, dim)
345    }
346
347    fn float_argtopk(
348        tensor: TchTensor,
349        dim: usize,
350        k: usize,
351        _indices_dtype: IntDType,
352    ) -> TchTensor {
353        TchOps::argtopk(tensor, dim, k)
354    }
355
356    fn float_topk(tensor: TchTensor, dim: usize, k: usize) -> TchTensor {
357        TchOps::topk(tensor, dim, k)
358    }
359
360    fn float_argmin(tensor: TchTensor, dim: usize, _out_dtype: IntDType) -> TchTensor {
361        TchOps::argmin(tensor, dim)
362    }
363
364    fn float_max_dim(tensor: TchTensor, dim: usize) -> TchTensor {
365        TchOps::max_dim(tensor, dim)
366    }
367
368    fn float_max_dim_with_indices(
369        tensor: TchTensor,
370        dim: usize,
371        _indices_dtype: IntDType,
372    ) -> (TchTensor, TchTensor) {
373        TchOps::max_dim_with_indices(tensor, dim)
374    }
375
376    fn float_min_dim(tensor: TchTensor, dim: usize) -> TchTensor {
377        TchOps::min_dim(tensor, dim)
378    }
379
380    fn float_min_dim_with_indices(
381        tensor: TchTensor,
382        dim: usize,
383        _indices_dtype: IntDType,
384    ) -> (TchTensor, TchTensor) {
385        TchOps::min_dim_with_indices(tensor, dim)
386    }
387
388    fn float_exp(tensor: TchTensor) -> TchTensor {
389        tensor.unary_ops(|mut tensor| tensor.exp_(), |tensor| tensor.exp())
390    }
391
392    fn float_log(tensor: TchTensor) -> TchTensor {
393        tensor.unary_ops(|mut tensor| tensor.log_(), |tensor| tensor.log())
394    }
395
396    fn float_log1p(tensor: TchTensor) -> TchTensor {
397        tensor.unary_ops(|mut tensor| tensor.log1p_(), |tensor| tensor.log1p())
398    }
399
400    fn float_powf_scalar_impl(tensor: TchTensor, value: Scalar) -> TchTensor {
401        tensor.unary_ops(
402            |mut tensor| tensor.f_pow_(value.elem::<f64>()).unwrap(),
403            |tensor| tensor.pow_tensor_scalar(value.elem::<f64>()),
404        )
405    }
406
407    fn float_sqrt(tensor: TchTensor) -> TchTensor {
408        tensor.unary_ops(|mut tensor| tensor.sqrt_(), |tensor| tensor.sqrt())
409    }
410
411    fn float_abs(tensor: TchTensor) -> TchTensor {
412        tensor.unary_ops(|mut tensor| tensor.abs_(), |tensor| tensor.abs())
413    }
414
415    fn float_cos(tensor: TchTensor) -> TchTensor {
416        tensor.unary_ops(|mut tensor| tensor.cos_(), |tensor| tensor.cos())
417    }
418
419    fn float_cosh(tensor: TchTensor) -> TchTensor {
420        tensor.unary_ops(|mut tensor| tensor.cosh_(), |tensor| tensor.cosh())
421    }
422
423    fn float_sin(tensor: TchTensor) -> TchTensor {
424        tensor.unary_ops(|mut tensor| tensor.sin_(), |tensor| tensor.sin())
425    }
426
427    fn float_sinh(tensor: TchTensor) -> TchTensor {
428        tensor.unary_ops(|mut tensor| tensor.sinh_(), |tensor| tensor.sinh())
429    }
430
431    fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
432        tensor.unary_ops(|mut tensor| tensor.tan_(), |tensor| tensor.tan())
433    }
434
435    fn float_tanh(tensor: TchTensor) -> TchTensor {
436        tensor.unary_ops(|mut tensor| tensor.tanh_(), |tensor| tensor.tanh())
437    }
438
439    fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
440        tensor.unary_ops(|mut tensor| tensor.acos_(), |tensor| tensor.acos())
441    }
442
443    fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
444        tensor.unary_ops(|mut tensor| tensor.acosh_(), |tensor| tensor.acosh())
445    }
446
447    fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
448        tensor.unary_ops(|mut tensor| tensor.asin_(), |tensor| tensor.asin())
449    }
450
451    fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
452        tensor.unary_ops(|mut tensor| tensor.asinh_(), |tensor| tensor.asinh())
453    }
454
455    fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
456        tensor.unary_ops(|mut tensor| tensor.atan_(), |tensor| tensor.atan())
457    }
458
459    fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
460        tensor.unary_ops(|mut tensor| tensor.atanh_(), |tensor| tensor.atanh())
461    }
462
463    fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
464        TchOps::atan2(lhs, rhs)
465    }
466
467    fn float_round(tensor: TchTensor) -> TchTensor {
468        tensor.unary_ops(|mut tensor| tensor.round_(), |tensor| tensor.round())
469    }
470
471    fn float_floor(tensor: TchTensor) -> TchTensor {
472        tensor.unary_ops(|mut tensor| tensor.floor_(), |tensor| tensor.floor())
473    }
474
475    fn float_ceil(tensor: TchTensor) -> TchTensor {
476        tensor.unary_ops(|mut tensor| tensor.ceil_(), |tensor| tensor.ceil())
477    }
478
479    fn float_trunc(tensor: TchTensor) -> TchTensor {
480        tensor.unary_ops(|mut tensor| tensor.trunc_(), |tensor| tensor.trunc())
481    }
482
483    fn float_erf(tensor: TchTensor) -> TchTensor {
484        tensor.unary_ops(|mut tensor| tensor.erf_(), |tensor| tensor.erf())
485    }
486
487    fn float_cat(tensors: Vec<TchTensor>, dim: usize) -> TchTensor {
488        TchOps::cat(tensors, dim)
489    }
490
491    fn float_clamp_min(tensor: TchTensor, min: Scalar) -> TchTensor {
492        TchOps::clamp_min(tensor, min.elem::<f64>())
493    }
494
495    fn float_clamp_max(tensor: TchTensor, max: Scalar) -> TchTensor {
496        TchOps::clamp_max(tensor, max.elem::<f64>())
497    }
498
499    fn float_clamp(tensor: TchTensor, min: Scalar, max: Scalar) -> TchTensor {
500        TchOps::clamp(tensor, min.elem::<f64>(), max.elem::<f64>())
501    }
502
503    fn float_into_int(tensor: TchTensor, _out_dtype: IntDType) -> TchTensor {
504        let tensor = tensor.tensor.to_kind(tch::Kind::Int64);
505        TchTensor::new(tensor)
506    }
507
508    fn float_powf(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
509        TchOps::pow(lhs, rhs)
510    }
511
512    fn float_permute(tensor: TchTensor, axes: &[usize]) -> TchTensor {
513        TchOps::permute(tensor, axes)
514    }
515
516    fn float_flip(tensor: TchTensor, axes: &[usize]) -> TchTensor {
517        TchOps::flip(tensor, axes)
518    }
519
520    fn float_sign(tensor: TchTensor) -> TchTensor {
521        TchOps::sign(tensor)
522    }
523
524    fn float_expand(tensor: TchTensor, shape: Shape) -> TchTensor {
525        TchOps::expand(tensor, shape)
526    }
527
528    fn float_sort(tensor: TchTensor, dim: usize, descending: bool) -> TchTensor {
529        TchOps::sort(tensor, dim, descending)
530    }
531
532    fn float_sort_with_indices(
533        tensor: TchTensor,
534        dim: usize,
535        descending: bool,
536        _indices_dtype: IntDType,
537    ) -> (TchTensor, TchTensor) {
538        TchOps::sort_with_indices(tensor, dim, descending)
539    }
540
541    fn float_argsort(
542        tensor: TchTensor,
543        dim: usize,
544        descending: bool,
545        _out_dtype: IntDType,
546    ) -> IntTensor<Self> {
547        TchOps::argsort(tensor, dim, descending)
548    }
549
550    fn float_cast(tensor: TchTensor, dtype: FloatDType) -> TchTensor {
551        // NOTE: when dtypes of inputs to an arithmetic operation differ, tch handles type
552        // promotion based on a set of rules: https://pytorch.org/docs/stable/tensor_attributes.html#type-promotion-doc
553
554        // Type promotion is not automatic on all backends so this behavior might differ
555        let kind = dtype.into_kind();
556
557        if tensor.tensor.kind() == kind {
558            tensor
559        } else {
560            TchTensor::new(tensor.tensor.to_kind(kind))
561        }
562    }
563
564    fn float_unfold(
565        tensor: FloatTensor<Self>,
566        dim: usize,
567        size: usize,
568        step: usize,
569    ) -> FloatTensor<Self> {
570        TchOps::unfold(tensor, dim, size, step)
571    }
572
573    fn float_is_nan(tensor: FloatTensor<Self>, _out_dtype: BoolDType) -> BoolTensor<Self> {
574        TchTensor::new(tensor.tensor.isnan())
575    }
576
577    fn float_is_inf(tensor: FloatTensor<Self>, _out_dtype: BoolDType) -> BoolTensor<Self> {
578        TchTensor::new(tensor.tensor.isinf())
579    }
580}