burn_tch/ops/tensor.rs

use super::TchOps;
use crate::{element::TchElement, LibTorch, LibTorchDevice, QuantElement, TchShape, TchTensor};
use burn_tensor::{
    backend::Backend,
    ops::{FloatTensorOps, IntTensor},
    Distribution, ElementConversion, FloatDType, Shape, TensorData, TensorMetadata,
};
use half::{bf16, f16};
use std::ops::Range;

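// Float tensor ops for the LibTorch backend. Most operations delegate to the shared
// `TchOps` helpers, while element-wise ops go through `TchTensor::unary_ops`/`mut_ops`,
// which run the in-place tch kernel when the tensor's storage can be mutated safely and
// otherwise fall back to the allocating variant.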
impl<E: TchElement, Q: QuantElement> FloatTensorOps<Self> for LibTorch<E, Q> {
    fn float_from_data(data: TensorData, device: &LibTorchDevice) -> TchTensor {
        TchTensor::from_data::<E>(data, (*device).into())
    }

    fn float_random(
        shape: Shape,
        distribution: Distribution,
        device: &LibTorchDevice,
    ) -> TchTensor {
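        // Each distribution maps to the matching in-place tch initializer on a freshly
        // allocated tensor: `rand_like_out` fills with `torch.rand`-style uniform values
        // in [0, 1) for the default distribution; `f_bernoulli_float_`, `uniform_`, and
        // `normal_` cover the remaining variants.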
        match distribution {
            Distribution::Default => {
                let mut tensor = TchTensor::empty::<E>(shape, *device);
                tensor
                    .mut_ops(|tensor| tensor.rand_like_out(tensor))
                    .unwrap()
            }
            Distribution::Bernoulli(prob) => {
                let mut tensor = TchTensor::empty::<E>(shape, *device);
                tensor
                    .mut_ops(|tensor| tensor.f_bernoulli_float_(prob).unwrap())
                    .unwrap()
            }
            Distribution::Uniform(from, to) => {
                let mut tensor = TchTensor::empty::<E>(shape, *device);
                tensor.mut_ops(|tensor| tensor.uniform_(from, to)).unwrap()
            }
            Distribution::Normal(mean, std) => {
                let mut tensor = TchTensor::empty::<E>(shape, *device);
                tensor.mut_ops(|tensor| tensor.normal_(mean, std)).unwrap()
            }
        }
    }

    fn float_repeat_dim(tensor: TchTensor, dim: usize, times: usize) -> TchTensor {
        TchOps::repeat_dim(tensor, dim, times)
    }

    fn float_zeros(shape: Shape, device: &LibTorchDevice) -> TchTensor {
        let shape = TchShape::from(shape);
        let device: tch::Device = (*device).into();

        TchTensor::new(tch::Tensor::zeros(shape.dims, (E::KIND, device)))
    }

    fn float_ones(shape: Shape, device: &LibTorchDevice) -> TchTensor {
        let shape = TchShape::from(shape);
        let device: tch::Device = (*device).into();

        TchTensor::new(tch::Tensor::ones(shape.dims, (E::KIND, device)))
    }

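    // Reading data back flattens the tensor to 1-D, then converts the tch tensor into a
    // `Vec` of the concrete element type matching its kind (f16/f32/f64/bf16).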
    async fn float_into_data(tensor: TchTensor) -> TensorData {
        let shape = tensor.shape();
        let tensor = Self::float_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
        match tensor.tensor.kind() {
            tch::Kind::Half => {
                let values: Vec<f16> = tensor.tensor.try_into().unwrap();
                TensorData::new(values, shape)
            }
            tch::Kind::Float => {
                let values: Vec<f32> = tensor.tensor.try_into().unwrap();
                TensorData::new(values, shape)
            }
            tch::Kind::Double => {
                let values: Vec<f64> = tensor.tensor.try_into().unwrap();
                TensorData::new(values, shape)
            }
            tch::Kind::BFloat16 => {
                let values: Vec<bf16> = tensor.tensor.try_into().unwrap();
                TensorData::new(values, shape)
            }
            _ => panic!("Not a valid float kind"),
        }
    }

    fn float_device(tensor: &TchTensor) -> LibTorchDevice {
        tensor.tensor.device().into()
    }

    fn float_to_device(tensor: TchTensor, device: &LibTorchDevice) -> TchTensor {
        TchOps::to_device(tensor, device)
    }

    fn float_empty(shape: Shape, device: &<LibTorch<E> as Backend>::Device) -> TchTensor {
        let tensor = tch::Tensor::empty(TchShape::from(shape).dims, (E::KIND, (*device).into()));

        TchTensor::new(tensor)
    }

    fn float_add(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
        TchOps::add(lhs, rhs)
    }

    fn float_add_scalar(lhs: TchTensor, rhs: E) -> TchTensor {
        let rhs: f64 = rhs.elem();

        lhs.unary_ops(
            |mut tensor| tensor.f_add_scalar_(rhs).unwrap(),
            |tensor| tensor.f_add_scalar(rhs).unwrap(),
        )
    }

    fn float_sub(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
        TchOps::sub(lhs, rhs)
    }

    fn float_sub_scalar(lhs: TchTensor, rhs: E) -> TchTensor {
        let rhs: f64 = rhs.elem();

        lhs.unary_ops(
            |mut tensor| tensor.f_sub_scalar_(rhs).unwrap(),
            |tensor| tensor.f_sub_scalar(rhs).unwrap(),
        )
    }

    fn float_mul(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
        TchOps::mul(lhs, rhs)
    }

    fn float_mul_scalar(lhs: TchTensor, rhs: E) -> TchTensor {
        let rhs: f64 = rhs.elem();

        lhs.unary_ops(
            |mut tensor| tensor.f_mul_scalar_(rhs).unwrap(),
            |tensor| tensor.f_mul_scalar(rhs).unwrap(),
        )
    }

    fn float_div(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
        TchOps::div(lhs, rhs)
    }

    fn float_div_scalar(lhs: TchTensor, rhs: E) -> TchTensor {
        let rhs: f64 = rhs.elem();

        lhs.unary_ops(
            |mut tensor| tensor.f_div_scalar_(rhs).unwrap(),
            |tensor| tensor.f_div_scalar(rhs).unwrap(),
        )
    }

    fn float_remainder(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
        TchOps::remainder(lhs, rhs)
    }

    fn float_remainder_scalar(lhs: TchTensor, rhs: E) -> TchTensor {
        let rhs: f64 = rhs.elem();

        lhs.unary_ops(
            |tensor| tensor.f_remainder(rhs).unwrap(),
            |tensor| tensor.f_remainder(rhs).unwrap(),
        )
    }

    fn float_matmul(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
        let tensor = lhs.tensor.matmul(&rhs.tensor);
        TchTensor::new(tensor)
    }

    fn float_neg(tensor: TchTensor) -> TchTensor {
        Self::float_mul_scalar(tensor, (-1f32).elem::<E>())
    }

    fn float_recip(tensor: TchTensor) -> TchTensor {
        TchTensor::new(tensor.tensor.reciprocal())
    }

    fn float_swap_dims(tensor: TchTensor, dim1: usize, dim2: usize) -> TchTensor {
        TchOps::swap_dims(tensor, dim1, dim2)
    }

    fn float_reshape(tensor: TchTensor, shape: Shape) -> TchTensor {
        TchOps::reshape(tensor, shape)
    }

    fn float_gather(dim: usize, tensor: TchTensor, indices: TchTensor) -> TchTensor {
        TchOps::gather(dim, tensor, indices)
    }

    fn float_scatter(
        dim: usize,
        tensor: TchTensor,
        indices: TchTensor,
        value: TchTensor,
    ) -> TchTensor {
        TchOps::scatter(dim, tensor, indices, value)
    }

    fn float_select(tensor: TchTensor, dim: usize, indices: TchTensor) -> TchTensor {
        TchOps::index_select_dim(tensor, dim, indices)
    }

    fn float_select_assign(
        tensor: TchTensor,
        dim: usize,
        indices: TchTensor,
        value: TchTensor,
    ) -> TchTensor {
        TchOps::select_assign(tensor, dim, indices, value)
    }

    fn float_slice(tensor: TchTensor, ranges: &[Range<usize>]) -> TchTensor {
        TchOps::slice(tensor, ranges)
    }

    fn float_slice_assign(
        tensor: TchTensor,
        ranges: &[Range<usize>],
        value: TchTensor,
    ) -> TchTensor {
        TchOps::slice_assign(tensor, ranges, value)
    }

    fn float_mask_where(tensor: TchTensor, mask: TchTensor, value: TchTensor) -> TchTensor {
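        // `where_self` follows `torch.where(condition, input, other)` semantics: elements
        // come from `value` where `mask` is true and from `tensor` elsewhere.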
        let output = value.tensor.where_self(&mask.tensor, &tensor.tensor);

        TchTensor::new(output)
    }

    fn float_mask_fill(tensor: TchTensor, mask: TchTensor, value: E) -> TchTensor {
        let value: f64 = value.elem();

        tensor.unary_ops(
            |mut tensor| tensor.f_masked_fill_(&mask.tensor, value).unwrap(),
            |tensor| tensor.f_masked_fill(&mask.tensor, value).unwrap(),
        )
    }

    fn float_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
        TchOps::equal(lhs, rhs)
    }

    fn float_equal_elem(lhs: TchTensor, rhs: E) -> TchTensor {
        TchOps::equal_elem(lhs, rhs.elem::<f64>())
    }

    fn float_greater(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
        TchOps::greater(lhs, rhs)
    }

    fn float_greater_elem(lhs: TchTensor, rhs: E) -> TchTensor {
        TchOps::greater_elem(lhs, rhs.elem::<f64>())
    }

    fn float_greater_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
        TchOps::greater_equal(lhs, rhs)
    }

    fn float_greater_equal_elem(lhs: TchTensor, rhs: E) -> TchTensor {
        TchOps::greater_equal_elem(lhs, rhs.elem::<f64>())
    }

    fn float_lower(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
        TchOps::lower(lhs, rhs)
    }

    fn float_lower_elem(lhs: TchTensor, rhs: E) -> TchTensor {
        TchOps::lower_elem(lhs, rhs.elem::<f64>())
    }

    fn float_lower_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
        TchOps::lower_equal(lhs, rhs)
    }

    fn float_lower_equal_elem(lhs: TchTensor, rhs: E) -> TchTensor {
        TchOps::lower_equal_elem(lhs, rhs.elem::<f64>())
    }

    fn float_mean(tensor: TchTensor) -> TchTensor {
        TchOps::mean(tensor)
    }

    fn float_sum(tensor: TchTensor) -> TchTensor {
        TchOps::sum(tensor)
    }

    fn float_sum_dim(tensor: TchTensor, dim: usize) -> TchTensor {
        TchOps::sum_dim(tensor, dim)
    }

    fn float_mean_dim(tensor: TchTensor, dim: usize) -> TchTensor {
        TchOps::mean_dim(tensor, dim)
    }

    fn float_prod(tensor: TchTensor) -> TchTensor {
        TchOps::prod(tensor)
    }

    fn float_prod_dim(tensor: TchTensor, dim: usize) -> TchTensor {
        TchOps::prod_dim(tensor, dim)
    }

    fn float_argmax(tensor: TchTensor, dim: usize) -> TchTensor {
        TchOps::argmax(tensor, dim)
    }

    fn float_argmin(tensor: TchTensor, dim: usize) -> TchTensor {
        TchOps::argmin(tensor, dim)
    }

    fn float_max_dim(tensor: TchTensor, dim: usize) -> TchTensor {
        TchOps::max_dim(tensor, dim)
    }

    fn float_max_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) {
        TchOps::max_dim_with_indices(tensor, dim)
    }

    fn float_min_dim(tensor: TchTensor, dim: usize) -> TchTensor {
        TchOps::min_dim(tensor, dim)
    }

    fn float_min_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) {
        TchOps::min_dim_with_indices(tensor, dim)
    }

    fn float_exp(tensor: TchTensor) -> TchTensor {
        tensor.unary_ops(|mut tensor| tensor.exp_(), |tensor| tensor.exp())
    }

    fn float_log(tensor: TchTensor) -> TchTensor {
        tensor.unary_ops(|mut tensor| tensor.log_(), |tensor| tensor.log())
    }

    fn float_log1p(tensor: TchTensor) -> TchTensor {
        tensor.unary_ops(|mut tensor| tensor.log1p_(), |tensor| tensor.log1p())
    }

    fn float_powf_scalar(tensor: TchTensor, value: f32) -> TchTensor {
        tensor.unary_ops(
            |mut tensor| tensor.f_pow_(value as f64).unwrap(),
            |tensor| tensor.pow_tensor_scalar(value as f64),
        )
    }

    fn float_sqrt(tensor: TchTensor) -> TchTensor {
        tensor.unary_ops(|mut tensor| tensor.sqrt_(), |tensor| tensor.sqrt())
    }

    fn float_abs(tensor: TchTensor) -> TchTensor {
        tensor.unary_ops(|mut tensor| tensor.abs_(), |tensor| tensor.abs())
    }

    fn float_cos(tensor: TchTensor) -> TchTensor {
        tensor.unary_ops(|mut tensor| tensor.cos_(), |tensor| tensor.cos())
    }

    fn float_sin(tensor: TchTensor) -> TchTensor {
        tensor.unary_ops(|mut tensor| tensor.sin_(), |tensor| tensor.sin())
    }

    fn float_tanh(tensor: TchTensor) -> TchTensor {
        tensor.unary_ops(|mut tensor| tensor.tanh_(), |tensor| tensor.tanh())
    }

    fn float_round(tensor: TchTensor) -> TchTensor {
        tensor.unary_ops(|mut tensor| tensor.round_(), |tensor| tensor.round())
    }

    fn float_floor(tensor: TchTensor) -> TchTensor {
        tensor.unary_ops(|mut tensor| tensor.floor_(), |tensor| tensor.floor())
    }

    fn float_ceil(tensor: TchTensor) -> TchTensor {
        tensor.unary_ops(|mut tensor| tensor.ceil_(), |tensor| tensor.ceil())
    }

    fn float_erf(tensor: TchTensor) -> TchTensor {
        tensor.unary_ops(|mut tensor| tensor.erf_(), |tensor| tensor.erf())
    }

    fn float_cat(tensors: Vec<TchTensor>, dim: usize) -> TchTensor {
        TchOps::cat(tensors, dim)
    }

    fn float_clamp_min(tensor: TchTensor, min: E) -> TchTensor {
        TchOps::clamp_min(tensor, min.elem::<f64>())
    }

    fn float_clamp_max(tensor: TchTensor, max: <LibTorch<E> as Backend>::FloatElem) -> TchTensor {
        TchOps::clamp_max(tensor, max.elem::<f64>())
    }

    fn float_clamp(
        tensor: TchTensor,
        min: <LibTorch<E> as Backend>::FloatElem,
        max: <LibTorch<E> as Backend>::FloatElem,
    ) -> TchTensor {
        TchOps::clamp(tensor, min.elem::<f64>(), max.elem::<f64>())
    }

    fn float_into_int(tensor: TchTensor) -> TchTensor {
        let tensor = tensor.tensor.to_kind(tch::Kind::Int64);
        TchTensor::new(tensor)
    }

    fn float_narrow(tensor: TchTensor, dim: usize, start: usize, length: usize) -> TchTensor {
        TchOps::narrow(tensor, dim, start, length)
    }

    fn float_chunk(tensor: TchTensor, chunks: usize, dim: usize) -> Vec<TchTensor> {
        TchOps::chunk(tensor, chunks, dim)
    }

    fn float_split(tensor: TchTensor, split_size: usize, dim: usize) -> Vec<TchTensor> {
        TchOps::split(tensor, split_size, dim)
    }

    fn float_split_with_sizes(
        tensor: TchTensor,
        split_sizes: Vec<usize>,
        dim: usize,
    ) -> Vec<TchTensor> {
        TchOps::split_with_sizes(tensor, split_sizes, dim)
    }

    fn float_powf(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
        TchOps::powf(lhs, rhs)
    }

    fn float_permute(tensor: TchTensor, axes: &[usize]) -> TchTensor {
        TchOps::permute(tensor, axes)
    }

    fn float_flip(tensor: TchTensor, axes: &[usize]) -> TchTensor {
        TchOps::flip(tensor, axes)
    }

    fn float_sign(tensor: TchTensor) -> TchTensor {
        TchOps::sign(tensor)
    }

    fn float_expand(tensor: TchTensor, shape: Shape) -> TchTensor {
        TchOps::expand(tensor, shape)
    }

    fn float_sort(tensor: TchTensor, dim: usize, descending: bool) -> TchTensor {
        TchOps::sort(tensor, dim, descending)
    }

    fn float_sort_with_indices(
        tensor: TchTensor,
        dim: usize,
        descending: bool,
    ) -> (TchTensor, TchTensor) {
        TchOps::sort_with_indices(tensor, dim, descending)
    }

    fn float_argsort(tensor: TchTensor, dim: usize, descending: bool) -> IntTensor<Self> {
        TchOps::argsort(tensor, dim, descending)
    }

    fn float_cast(tensor: TchTensor, dtype: FloatDType) -> TchTensor {
        // NOTE: when the dtypes of inputs to an arithmetic operation differ, tch handles type
        // promotion based on a set of rules: https://pytorch.org/docs/stable/tensor_attributes.html#type-promotion-doc

        // Type promotion is not automatic on all backends, so this behavior might differ.
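        // For example, under those rules an f16 tensor added to an f32 tensor is promoted to f32.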
        let kind = match dtype {
            FloatDType::F64 => tch::Kind::Double,
            FloatDType::F32 => tch::Kind::Float,
            FloatDType::F16 => tch::Kind::Half,
            FloatDType::BF16 => tch::Kind::BFloat16,
        };

        if tensor.tensor.kind() == kind {
            tensor
        } else {
            TchTensor::new(tensor.tensor.to_kind(kind))
        }
    }
}
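
#[cfg(test)]
mod float_ops_sketch {
    // A minimal usage sketch, not part of the original file: it exercises a couple of the
    // ops above through the public `burn_tensor::Tensor` API. The element type `f32`, the
    // CPU device, and the module/test names are assumptions for illustration only.
    use crate::{LibTorch, LibTorchDevice};
    use burn_tensor::{Distribution, Tensor};

    #[test]
    fn matmul_example() {
        type B = LibTorch<f32>;
        let device = LibTorchDevice::Cpu;

        // `random` dispatches to `float_random` and `matmul` to `float_matmul` above.
        let lhs = Tensor::<B, 2>::random([2, 3], Distribution::Default, &device);
        let rhs = Tensor::<B, 2>::random([3, 4], Distribution::Default, &device);
        let out = lhs.matmul(rhs);

        assert_eq!(out.dims(), [2, 4]);
    }
}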