burn_tch/ops/tensor.rs

1use super::TchOps;
2use crate::{IntoKind, LibTorch, LibTorchDevice, TchShape, TchTensor, element::TchElement};
3use burn_tensor::ops::{BoolTensor, FloatTensor};
4use burn_tensor::{
5    DType, Distribution, ElementConversion, FloatDType, Shape, TensorData, TensorMetadata,
6    backend::Backend,
7    ops::{FloatTensorOps, IntTensor},
8};
9use half::{bf16, f16};
10
11impl<E: TchElement> FloatTensorOps<Self> for LibTorch<E> {
12    fn float_from_data(data: TensorData, device: &LibTorchDevice) -> TchTensor {
13        match data.dtype {
14            DType::F64 => TchTensor::from_data::<f64>(data, (*device).into()),
15            DType::F32 => TchTensor::from_data::<f32>(data, (*device).into()),
16            DType::F16 => TchTensor::from_data::<f16>(data, (*device).into()),
17            DType::BF16 => TchTensor::from_data::<bf16>(data, (*device).into()),
18            _ => unimplemented!("Unsupported dtype for `float_from_data`"),
19        }
20    }
21
22    fn float_random(
23        shape: Shape,
24        distribution: Distribution,
25        device: &LibTorchDevice,
26    ) -> TchTensor {
27        match distribution {
28            Distribution::Default => {
29                let mut tensor = TchTensor::empty::<E>(shape, *device);
30                tensor
31                    .mut_ops(|tensor| tensor.rand_like_out(tensor))
32                    .unwrap()
33            }
34            Distribution::Bernoulli(prob) => {
35                let mut tensor = TchTensor::empty::<E>(shape, *device);
36                tensor
37                    .mut_ops(|tensor| tensor.f_bernoulli_float_(prob).unwrap())
38                    .unwrap()
39            }
40            Distribution::Uniform(from, to) => {
41                let mut tensor = TchTensor::empty::<E>(shape, *device);
42                tensor.mut_ops(|tensor| tensor.uniform_(from, to)).unwrap()
43            }
44            Distribution::Normal(mean, std) => {
45                let mut tensor = TchTensor::empty::<E>(shape, *device);
46                tensor.mut_ops(|tensor| tensor.normal_(mean, std)).unwrap()
47            }
48        }
49    }
50
51    fn float_repeat_dim(tensor: TchTensor, dim: usize, times: usize) -> TchTensor {
52        TchOps::repeat_dim(tensor, dim, times)
53    }
54
55    fn float_zeros(shape: Shape, device: &LibTorchDevice, dtype: FloatDType) -> TchTensor {
56        let shape = TchShape::from(shape);
57        let device: tch::Device = (*device).into();
58
59        TchTensor::new(tch::Tensor::zeros(shape.dims, (dtype.into_kind(), device)))
60    }
61
62    fn float_ones(shape: Shape, device: &LibTorchDevice, dtype: FloatDType) -> TchTensor {
63        let shape = TchShape::from(shape);
64        let device: tch::Device = (*device).into();
65
66        TchTensor::new(tch::Tensor::ones(shape.dims, (dtype.into_kind(), device)))
67    }
68
69    async fn float_into_data(tensor: TchTensor) -> TensorData {
70        let shape = tensor.shape();
71        let tensor = Self::float_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
72        match tensor.tensor.kind() {
73            tch::Kind::Half => {
74                let values: Vec<f16> = tensor.tensor.try_into().unwrap();
75                TensorData::new(values, shape)
76            }
77            tch::Kind::Float => {
78                let values: Vec<f32> = tensor.tensor.try_into().unwrap();
79                TensorData::new(values, shape)
80            }
81            tch::Kind::Double => {
82                let values: Vec<f64> = tensor.tensor.try_into().unwrap();
83                TensorData::new(values, shape)
84            }
85            tch::Kind::BFloat16 => {
86                let values: Vec<bf16> = tensor.tensor.try_into().unwrap();
87                TensorData::new(values, shape)
88            }
89            _ => panic!("Not a valid float kind"),
90        }
91    }
92
93    fn float_device(tensor: &TchTensor) -> LibTorchDevice {
94        tensor.tensor.device().into()
95    }
96
97    fn float_to_device(tensor: TchTensor, device: &LibTorchDevice) -> TchTensor {
98        TchOps::to_device(tensor, device)
99    }
100
101    fn float_empty(
102        shape: Shape,
103        device: &<LibTorch<E> as Backend>::Device,
104        dtype: FloatDType,
105    ) -> TchTensor {
106        let tensor = tch::Tensor::empty(
107            TchShape::from(shape).dims,
108            (dtype.into_kind(), (*device).into()),
109        );
110
111        TchTensor::new(tensor)
112    }
113
114    fn float_add(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
115        TchOps::add(lhs, rhs)
116    }
117
118    fn float_add_scalar(lhs: TchTensor, rhs: E) -> TchTensor {
119        let rhs: f64 = rhs.elem();
120
121        lhs.unary_ops(
122            |mut tensor| tensor.f_add_scalar_(rhs).unwrap(),
123            |tensor| tensor.f_add_scalar(rhs).unwrap(),
124        )
125    }
126
127    fn float_sub(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
128        TchOps::sub(lhs, rhs)
129    }
130
131    fn float_sub_scalar(lhs: TchTensor, rhs: E) -> TchTensor {
132        let rhs: f64 = rhs.elem();
133
134        lhs.unary_ops(
135            |mut tensor| tensor.f_sub_scalar_(rhs).unwrap(),
136            |tensor| tensor.f_sub_scalar(rhs).unwrap(),
137        )
138    }
139
140    fn float_mul(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
141        TchOps::mul(lhs, rhs)
142    }
143
144    fn float_mul_scalar(lhs: TchTensor, rhs: E) -> TchTensor {
145        let rhs: f64 = rhs.elem();
146
147        lhs.unary_ops(
148            |mut tensor| tensor.f_mul_scalar_(rhs).unwrap(),
149            |tensor| tensor.f_mul_scalar(rhs).unwrap(),
150        )
151    }
152
153    fn float_div(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
154        TchOps::div(lhs, rhs)
155    }
156
157    fn float_div_scalar(lhs: TchTensor, rhs: E) -> TchTensor {
158        let rhs: f64 = rhs.elem();
159
160        lhs.unary_ops(
161            |mut tensor| tensor.f_div_scalar_(rhs).unwrap(),
162            |tensor| tensor.f_div_scalar(rhs).unwrap(),
163        )
164    }
165
166    fn float_remainder(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
167        TchOps::remainder(lhs, rhs)
168    }
169
170    fn float_remainder_scalar(lhs: TchTensor, rhs: E) -> TchTensor {
171        let rhs: f64 = rhs.elem();
172
173        lhs.unary_ops(
174            |tensor| tensor.f_remainder(rhs).unwrap(),
175            |tensor| tensor.f_remainder(rhs).unwrap(),
176        )
177    }
178
179    fn float_matmul(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
180        let tensor = lhs.tensor.matmul(&rhs.tensor);
181        TchTensor::new(tensor)
182    }
183
184    fn float_cross(lhs: TchTensor, rhs: TchTensor, dim: usize) -> TchTensor {
185        let tensor = lhs.tensor.cross(&rhs.tensor, dim as i64);
186        TchTensor::new(tensor)
187    }
188
189    fn float_neg(tensor: TchTensor) -> TchTensor {
190        Self::float_mul_scalar(tensor, (-1f32).elem::<E>())
191    }
192
193    fn float_recip(tensor: TchTensor) -> TchTensor {
194        TchTensor::new(tensor.tensor.reciprocal())
195    }
196
197    fn float_swap_dims(tensor: TchTensor, dim1: usize, dim2: usize) -> TchTensor {
198        TchOps::swap_dims(tensor, dim1, dim2)
199    }
200
201    fn float_reshape(tensor: TchTensor, shape: Shape) -> TchTensor {
202        TchOps::reshape(tensor, shape)
203    }
204
205    fn float_gather(dim: usize, tensor: TchTensor, indices: TchTensor) -> TchTensor {
206        TchOps::gather(dim, tensor, indices)
207    }
208
209    fn float_scatter(
210        dim: usize,
211        tensor: TchTensor,
212        indices: TchTensor,
213        value: TchTensor,
214    ) -> TchTensor {
215        TchOps::scatter(dim, tensor, indices, value)
216    }
217
218    fn float_select(tensor: TchTensor, dim: usize, indices: TchTensor) -> TchTensor {
219        TchOps::index_select_dim(tensor, dim, indices)
220    }
221
222    fn float_select_assign(
223        tensor: TchTensor,
224        dim: usize,
225        indices: TchTensor,
226        value: TchTensor,
227    ) -> TchTensor {
228        TchOps::select_assign(tensor, dim, indices, value)
229    }
230
231    fn float_slice(tensor: TchTensor, slices: &[burn_tensor::Slice]) -> TchTensor {
232        TchOps::slice_with_steps(tensor, slices)
233    }
234
235    fn float_slice_assign(
236        tensor: TchTensor,
237        slices: &[burn_tensor::Slice],
238        value: TchTensor,
239    ) -> TchTensor {
240        TchOps::slice_assign(tensor, slices, value)
241    }
242
243    fn float_mask_where(tensor: TchTensor, mask: TchTensor, value: TchTensor) -> TchTensor {
244        let output = value.tensor.where_self(&mask.tensor, &tensor.tensor);
245
246        TchTensor::new(output)
247    }
248
249    fn float_mask_fill(tensor: TchTensor, mask: TchTensor, value: E) -> TchTensor {
250        let value: f64 = value.elem();
251
252        tensor.unary_ops(
253            |mut tensor| tensor.f_masked_fill_(&mask.tensor, value).unwrap(),
254            |tensor| tensor.f_masked_fill(&mask.tensor, value).unwrap(),
255        )
256    }
257
258    fn float_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
259        TchOps::equal(lhs, rhs)
260    }
261
262    fn float_equal_elem(lhs: TchTensor, rhs: E) -> TchTensor {
263        TchOps::equal_elem(lhs, rhs.elem::<f64>())
264    }
265
266    fn float_greater(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
267        TchOps::greater(lhs, rhs)
268    }
269
270    fn float_greater_elem(lhs: TchTensor, rhs: E) -> TchTensor {
271        TchOps::greater_elem(lhs, rhs.elem::<f64>())
272    }
273
274    fn float_greater_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
275        TchOps::greater_equal(lhs, rhs)
276    }
277
278    fn float_greater_equal_elem(lhs: TchTensor, rhs: E) -> TchTensor {
279        TchOps::greater_equal_elem(lhs, rhs.elem::<f64>())
280    }
281
282    fn float_lower(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
283        TchOps::lower(lhs, rhs)
284    }
285
286    fn float_lower_elem(lhs: TchTensor, rhs: E) -> TchTensor {
287        TchOps::lower_elem(lhs, rhs.elem::<f64>())
288    }
289
290    fn float_lower_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
291        TchOps::lower_equal(lhs, rhs)
292    }
293
294    fn float_lower_equal_elem(lhs: TchTensor, rhs: E) -> TchTensor {
295        TchOps::lower_equal_elem(lhs, rhs.elem::<f64>())
296    }
297
298    fn float_mean(tensor: TchTensor) -> TchTensor {
299        TchOps::mean(tensor)
300    }
301
302    fn float_sum(tensor: TchTensor) -> TchTensor {
303        TchOps::sum(tensor)
304    }
305
306    fn float_sum_dim(tensor: TchTensor, dim: usize) -> TchTensor {
307        TchOps::sum_dim(tensor, dim)
308    }
309
310    fn float_mean_dim(tensor: TchTensor, dim: usize) -> TchTensor {
311        TchOps::mean_dim(tensor, dim)
312    }
313
314    fn float_cumsum(tensor: TchTensor, dim: usize) -> TchTensor {
315        TchOps::cumsum(tensor, dim)
316    }
317
318    fn float_cumprod(tensor: TchTensor, dim: usize) -> TchTensor {
319        TchOps::cumprod(tensor, dim)
320    }
321
322    fn float_cummin(tensor: TchTensor, dim: usize) -> TchTensor {
323        TchOps::cummin(tensor, dim)
324    }
325
326    fn float_cummax(tensor: TchTensor, dim: usize) -> TchTensor {
327        TchOps::cummax(tensor, dim)
328    }
329
330    fn float_prod(tensor: TchTensor) -> TchTensor {
331        TchOps::prod(tensor)
332    }
333
334    fn float_prod_dim(tensor: TchTensor, dim: usize) -> TchTensor {
335        TchOps::prod_dim(tensor, dim)
336    }
337
338    fn float_argmax(tensor: TchTensor, dim: usize) -> TchTensor {
339        TchOps::argmax(tensor, dim)
340    }
341
342    fn float_argmin(tensor: TchTensor, dim: usize) -> TchTensor {
343        TchOps::argmin(tensor, dim)
344    }
345
346    fn float_max_dim(tensor: TchTensor, dim: usize) -> TchTensor {
347        TchOps::max_dim(tensor, dim)
348    }
349
350    fn float_max_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) {
351        TchOps::max_dim_with_indices(tensor, dim)
352    }
353
354    fn float_min_dim(tensor: TchTensor, dim: usize) -> TchTensor {
355        TchOps::min_dim(tensor, dim)
356    }
357
358    fn float_min_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) {
359        TchOps::min_dim_with_indices(tensor, dim)
360    }
361
362    fn float_exp(tensor: TchTensor) -> TchTensor {
363        tensor.unary_ops(|mut tensor| tensor.exp_(), |tensor| tensor.exp())
364    }
365
366    fn float_log(tensor: TchTensor) -> TchTensor {
367        tensor.unary_ops(|mut tensor| tensor.log_(), |tensor| tensor.log())
368    }
369
370    fn float_log1p(tensor: TchTensor) -> TchTensor {
371        tensor.unary_ops(|mut tensor| tensor.log1p_(), |tensor| tensor.log1p())
372    }
373
374    fn float_powf_scalar_impl(tensor: TchTensor, value: f32) -> TchTensor {
375        tensor.unary_ops(
376            |mut tensor| tensor.f_pow_(value as f64).unwrap(),
377            |tensor| tensor.pow_tensor_scalar(value as f64),
378        )
379    }
380
381    fn float_sqrt(tensor: TchTensor) -> TchTensor {
382        tensor.unary_ops(|mut tensor| tensor.sqrt_(), |tensor| tensor.sqrt())
383    }
384
385    fn float_abs(tensor: TchTensor) -> TchTensor {
386        tensor.unary_ops(|mut tensor| tensor.abs_(), |tensor| tensor.abs())
387    }
388
389    fn float_cos(tensor: TchTensor) -> TchTensor {
390        tensor.unary_ops(|mut tensor| tensor.cos_(), |tensor| tensor.cos())
391    }
392
393    fn float_sin(tensor: TchTensor) -> TchTensor {
394        tensor.unary_ops(|mut tensor| tensor.sin_(), |tensor| tensor.sin())
395    }
396
397    fn float_tanh(tensor: TchTensor) -> TchTensor {
398        tensor.unary_ops(|mut tensor| tensor.tanh_(), |tensor| tensor.tanh())
399    }
400
401    fn float_round(tensor: TchTensor) -> TchTensor {
402        tensor.unary_ops(|mut tensor| tensor.round_(), |tensor| tensor.round())
403    }
404
405    fn float_floor(tensor: TchTensor) -> TchTensor {
406        tensor.unary_ops(|mut tensor| tensor.floor_(), |tensor| tensor.floor())
407    }
408
409    fn float_ceil(tensor: TchTensor) -> TchTensor {
410        tensor.unary_ops(|mut tensor| tensor.ceil_(), |tensor| tensor.ceil())
411    }
412
413    fn float_trunc(tensor: TchTensor) -> TchTensor {
414        tensor.unary_ops(|mut tensor| tensor.trunc_(), |tensor| tensor.trunc())
415    }
416
417    fn float_erf(tensor: TchTensor) -> TchTensor {
418        tensor.unary_ops(|mut tensor| tensor.erf_(), |tensor| tensor.erf())
419    }
420
421    fn float_cat(tensors: Vec<TchTensor>, dim: usize) -> TchTensor {
422        TchOps::cat(tensors, dim)
423    }
424
425    fn float_clamp_min(tensor: TchTensor, min: E) -> TchTensor {
426        TchOps::clamp_min(tensor, min.elem::<f64>())
427    }
428
429    fn float_clamp_max(tensor: TchTensor, max: <LibTorch<E> as Backend>::FloatElem) -> TchTensor {
430        TchOps::clamp_max(tensor, max.elem::<f64>())
431    }
432
433    fn float_clamp(
434        tensor: TchTensor,
435        min: <LibTorch<E> as Backend>::FloatElem,
436        max: <LibTorch<E> as Backend>::FloatElem,
437    ) -> TchTensor {
438        TchOps::clamp(tensor, min.elem::<f64>(), max.elem::<f64>())
439    }
440
441    fn float_into_int(tensor: TchTensor) -> TchTensor {
442        let tensor = tensor.tensor.to_kind(tch::Kind::Int64);
443        TchTensor::new(tensor)
444    }
445
446    fn float_powf(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
447        TchOps::powf(lhs, rhs)
448    }
449
450    fn float_permute(tensor: TchTensor, axes: &[usize]) -> TchTensor {
451        TchOps::permute(tensor, axes)
452    }
453
454    fn float_flip(tensor: TchTensor, axes: &[usize]) -> TchTensor {
455        TchOps::flip(tensor, axes)
456    }
457
458    fn float_sign(tensor: TchTensor) -> TchTensor {
459        TchOps::sign(tensor)
460    }
461
462    fn float_expand(tensor: TchTensor, shape: Shape) -> TchTensor {
463        TchOps::expand(tensor, shape)
464    }
465
466    fn float_sort(tensor: TchTensor, dim: usize, descending: bool) -> TchTensor {
467        TchOps::sort(tensor, dim, descending)
468    }
469
470    fn float_sort_with_indices(
471        tensor: TchTensor,
472        dim: usize,
473        descending: bool,
474    ) -> (TchTensor, TchTensor) {
475        TchOps::sort_with_indices(tensor, dim, descending)
476    }
477
478    fn float_argsort(tensor: TchTensor, dim: usize, descending: bool) -> IntTensor<Self> {
479        TchOps::argsort(tensor, dim, descending)
480    }
481
482    fn float_cast(tensor: TchTensor, dtype: FloatDType) -> TchTensor {
483        // NOTE: when dtypes of inputs to an arithmetic operation differ, tch handles type
484        // promotion based on a set of rules: https://pytorch.org/docs/stable/tensor_attributes.html#type-promotion-doc
485
486        // Type promotion is not automatic on all backends so this behavior might differ
487        let kind = match dtype {
488            FloatDType::F64 => tch::Kind::Double,
489            FloatDType::F32 => tch::Kind::Float,
490            FloatDType::Flex32 => tch::Kind::Float,
491            FloatDType::F16 => tch::Kind::Half,
492            FloatDType::BF16 => tch::Kind::BFloat16,
493        };
494
495        if tensor.tensor.kind() == kind {
496            tensor
497        } else {
498            TchTensor::new(tensor.tensor.to_kind(kind))
499        }
500    }
501
502    fn float_unfold(
503        tensor: FloatTensor<Self>,
504        dim: usize,
505        size: usize,
506        step: usize,
507    ) -> FloatTensor<Self> {
508        TchOps::unfold(tensor, dim, size, step)
509    }
510
511    fn float_is_nan(tensor: FloatTensor<Self>) -> BoolTensor<Self> {
512        TchTensor::new(tensor.tensor.isnan())
513    }
514
515    fn float_is_inf(tensor: FloatTensor<Self>) -> BoolTensor<Self> {
516        TchTensor::new(tensor.tensor.isinf())
517    }
518}