1use super::TchOps;
2use crate::{IntoKind, LibTorch, LibTorchDevice, TchShape, TchTensor, element::TchElement};
3use burn_backend::backend::ExecutionError;
4use burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor};
5use burn_backend::{
6 DType, Distribution, ElementConversion, FloatDType, Shape, TensorData, TensorMetadata,
7 backend::Backend, ops::FloatTensorOps,
8};
9use burn_backend::{bf16, f16};
10
11impl<E: TchElement> FloatTensorOps<Self> for LibTorch<E> {
12 fn float_from_data(data: TensorData, device: &LibTorchDevice) -> TchTensor {
13 match data.dtype {
14 DType::F64 => TchTensor::from_data::<f64>(data, (*device).into()),
15 DType::F32 => TchTensor::from_data::<f32>(data, (*device).into()),
16 DType::F16 => TchTensor::from_data::<f16>(data, (*device).into()),
17 DType::BF16 => TchTensor::from_data::<bf16>(data, (*device).into()),
18 _ => unimplemented!("Unsupported dtype for `float_from_data`"),
19 }
20 }
21
22 fn float_random(
23 shape: Shape,
24 distribution: Distribution,
25 device: &LibTorchDevice,
26 ) -> TchTensor {
27 match distribution {
28 Distribution::Default => {
29 let mut tensor = TchTensor::empty::<E>(shape, *device);
30 tensor
31 .mut_ops(|tensor| tensor.rand_like_out(tensor))
32 .unwrap()
33 }
34 Distribution::Bernoulli(prob) => {
35 let mut tensor = TchTensor::empty::<E>(shape, *device);
36 tensor
37 .mut_ops(|tensor| tensor.f_bernoulli_float_(prob).unwrap())
38 .unwrap()
39 }
40 Distribution::Uniform(from, to) => {
41 let mut tensor = TchTensor::empty::<E>(shape, *device);
42 tensor.mut_ops(|tensor| tensor.uniform_(from, to)).unwrap()
43 }
44 Distribution::Normal(mean, std) => {
45 let mut tensor = TchTensor::empty::<E>(shape, *device);
46 tensor.mut_ops(|tensor| tensor.normal_(mean, std)).unwrap()
47 }
48 }
49 }
50
51 fn float_repeat_dim(tensor: TchTensor, dim: usize, times: usize) -> TchTensor {
52 TchOps::repeat_dim(tensor, dim, times)
53 }
54
55 fn float_zeros(shape: Shape, device: &LibTorchDevice, dtype: FloatDType) -> TchTensor {
56 let shape = TchShape::from(shape);
57 let device: tch::Device = (*device).into();
58
59 TchTensor::new(tch::Tensor::zeros(shape.dims, (dtype.into_kind(), device)))
60 }
61
62 fn float_ones(shape: Shape, device: &LibTorchDevice, dtype: FloatDType) -> TchTensor {
63 let shape = TchShape::from(shape);
64 let device: tch::Device = (*device).into();
65
66 TchTensor::new(tch::Tensor::ones(shape.dims, (dtype.into_kind(), device)))
67 }
68
69 async fn float_into_data(tensor: TchTensor) -> Result<TensorData, ExecutionError> {
70 let shape = tensor.shape();
71 let tensor = Self::float_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
72 Ok(match tensor.tensor.kind() {
73 tch::Kind::Half => {
74 let values = Vec::<f16>::try_from(&tensor).unwrap();
75 TensorData::new(values, shape)
76 }
77 tch::Kind::Float => {
78 let values = Vec::<f32>::try_from(&tensor).unwrap();
79 TensorData::new(values, shape)
80 }
81 tch::Kind::Double => {
82 let values = Vec::<f64>::try_from(&tensor).unwrap();
83 TensorData::new(values, shape)
84 }
85 tch::Kind::BFloat16 => {
86 let values = Vec::<bf16>::try_from(&tensor).unwrap();
87 TensorData::new(values, shape)
88 }
89 _ => panic!("Not a valid float kind"),
90 })
91 }
92
93 fn float_device(tensor: &TchTensor) -> LibTorchDevice {
94 tensor.tensor.device().into()
95 }
96
97 fn float_to_device(tensor: TchTensor, device: &LibTorchDevice) -> TchTensor {
98 TchOps::to_device(tensor, device)
99 }
100
101 fn float_empty(
102 shape: Shape,
103 device: &<LibTorch<E> as Backend>::Device,
104 dtype: FloatDType,
105 ) -> TchTensor {
106 let tensor = tch::Tensor::empty(
107 TchShape::from(shape).dims,
108 (dtype.into_kind(), (*device).into()),
109 );
110
111 TchTensor::new(tensor)
112 }
113
114 fn float_add(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
115 TchOps::add(lhs, rhs)
116 }
117
118 fn float_add_scalar(lhs: TchTensor, rhs: E) -> TchTensor {
119 let rhs: f64 = rhs.elem();
120
121 lhs.unary_ops(
122 |mut tensor| tensor.f_add_scalar_(rhs).unwrap(),
123 |tensor| tensor.f_add_scalar(rhs).unwrap(),
124 )
125 }
126
127 fn float_sub(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
128 TchOps::sub(lhs, rhs)
129 }
130
131 fn float_sub_scalar(lhs: TchTensor, rhs: E) -> TchTensor {
132 let rhs: f64 = rhs.elem();
133
134 lhs.unary_ops(
135 |mut tensor| tensor.f_sub_scalar_(rhs).unwrap(),
136 |tensor| tensor.f_sub_scalar(rhs).unwrap(),
137 )
138 }
139
140 fn float_mul(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
141 TchOps::mul(lhs, rhs)
142 }
143
144 fn float_mul_scalar(lhs: TchTensor, rhs: E) -> TchTensor {
145 let rhs: f64 = rhs.elem();
146
147 lhs.unary_ops(
148 |mut tensor| tensor.f_mul_scalar_(rhs).unwrap(),
149 |tensor| tensor.f_mul_scalar(rhs).unwrap(),
150 )
151 }
152
153 fn float_div(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
154 TchOps::div(lhs, rhs)
155 }
156
157 fn float_div_scalar(lhs: TchTensor, rhs: E) -> TchTensor {
158 let rhs: f64 = rhs.elem();
159
160 lhs.unary_ops(
161 |mut tensor| tensor.f_div_scalar_(rhs).unwrap(),
162 |tensor| tensor.f_div_scalar(rhs).unwrap(),
163 )
164 }
165
166 fn float_remainder(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
167 TchOps::remainder(lhs, rhs)
168 }
169
170 fn float_remainder_scalar(lhs: TchTensor, rhs: E) -> TchTensor {
171 let rhs: f64 = rhs.elem();
172
173 lhs.unary_ops(
174 |tensor| tensor.f_remainder(rhs).unwrap(),
175 |tensor| tensor.f_remainder(rhs).unwrap(),
176 )
177 }
178
179 fn float_matmul(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
180 let tensor = lhs.tensor.matmul(&rhs.tensor);
181 TchTensor::new(tensor)
182 }
183
184 fn float_cross(lhs: TchTensor, rhs: TchTensor, dim: usize) -> TchTensor {
185 let tensor = lhs.tensor.cross(&rhs.tensor, dim as i64);
186 TchTensor::new(tensor)
187 }
188
189 fn float_neg(tensor: TchTensor) -> TchTensor {
190 Self::float_mul_scalar(tensor, (-1f32).elem::<E>())
191 }
192
193 fn float_recip(tensor: TchTensor) -> TchTensor {
194 TchTensor::new(tensor.tensor.reciprocal())
195 }
196
197 fn float_swap_dims(tensor: TchTensor, dim1: usize, dim2: usize) -> TchTensor {
198 TchOps::swap_dims(tensor, dim1, dim2)
199 }
200
201 fn float_reshape(tensor: TchTensor, shape: Shape) -> TchTensor {
202 TchOps::reshape(tensor, shape)
203 }
204
205 fn float_gather(dim: usize, tensor: TchTensor, indices: TchTensor) -> TchTensor {
206 TchOps::gather(dim, tensor, indices)
207 }
208
209 fn float_scatter_add(
210 dim: usize,
211 tensor: TchTensor,
212 indices: TchTensor,
213 value: TchTensor,
214 ) -> TchTensor {
215 TchOps::scatter(dim, tensor, indices, value)
216 }
217
218 fn float_select(tensor: TchTensor, dim: usize, indices: TchTensor) -> TchTensor {
219 TchOps::index_select_dim(tensor, dim, indices)
220 }
221
222 fn float_select_add(
223 tensor: TchTensor,
224 dim: usize,
225 indices: TchTensor,
226 value: TchTensor,
227 ) -> TchTensor {
228 TchOps::select_assign(tensor, dim, indices, value)
229 }
230
231 fn float_slice(tensor: TchTensor, slices: &[burn_backend::Slice]) -> TchTensor {
232 TchOps::slice_with_steps(tensor, slices)
233 }
234
235 fn float_slice_assign(
236 tensor: TchTensor,
237 slices: &[burn_backend::Slice],
238 value: TchTensor,
239 ) -> TchTensor {
240 TchOps::slice_assign(tensor, slices, value)
241 }
242
243 fn float_mask_where(tensor: TchTensor, mask: TchTensor, value: TchTensor) -> TchTensor {
244 let output = value.tensor.where_self(&mask.tensor, &tensor.tensor);
245
246 TchTensor::new(output)
247 }
248
249 fn float_mask_fill(tensor: TchTensor, mask: TchTensor, value: E) -> TchTensor {
250 let value: f64 = value.elem();
251
252 tensor.unary_ops(
253 |mut tensor| tensor.f_masked_fill_(&mask.tensor, value).unwrap(),
254 |tensor| tensor.f_masked_fill(&mask.tensor, value).unwrap(),
255 )
256 }
257
258 fn float_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
259 TchOps::equal(lhs, rhs)
260 }
261
262 fn float_equal_elem(lhs: TchTensor, rhs: E) -> TchTensor {
263 TchOps::equal_elem(lhs, rhs.elem::<f64>())
264 }
265
266 fn float_greater(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
267 TchOps::greater(lhs, rhs)
268 }
269
270 fn float_greater_elem(lhs: TchTensor, rhs: E) -> TchTensor {
271 TchOps::greater_elem(lhs, rhs.elem::<f64>())
272 }
273
274 fn float_greater_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
275 TchOps::greater_equal(lhs, rhs)
276 }
277
278 fn float_greater_equal_elem(lhs: TchTensor, rhs: E) -> TchTensor {
279 TchOps::greater_equal_elem(lhs, rhs.elem::<f64>())
280 }
281
282 fn float_lower(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
283 TchOps::lower(lhs, rhs)
284 }
285
286 fn float_lower_elem(lhs: TchTensor, rhs: E) -> TchTensor {
287 TchOps::lower_elem(lhs, rhs.elem::<f64>())
288 }
289
290 fn float_lower_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
291 TchOps::lower_equal(lhs, rhs)
292 }
293
294 fn float_lower_equal_elem(lhs: TchTensor, rhs: E) -> TchTensor {
295 TchOps::lower_equal_elem(lhs, rhs.elem::<f64>())
296 }
297
298 fn float_mean(tensor: TchTensor) -> TchTensor {
299 TchOps::mean(tensor)
300 }
301
302 fn float_sum(tensor: TchTensor) -> TchTensor {
303 TchOps::sum(tensor)
304 }
305
306 fn float_sum_dim(tensor: TchTensor, dim: usize) -> TchTensor {
307 TchOps::sum_dim(tensor, dim)
308 }
309
310 fn float_mean_dim(tensor: TchTensor, dim: usize) -> TchTensor {
311 TchOps::mean_dim(tensor, dim)
312 }
313
314 fn float_cumsum(tensor: TchTensor, dim: usize) -> TchTensor {
315 TchOps::cumsum(tensor, dim)
316 }
317
318 fn float_cumprod(tensor: TchTensor, dim: usize) -> TchTensor {
319 TchOps::cumprod(tensor, dim)
320 }
321
322 fn float_cummin(tensor: TchTensor, dim: usize) -> TchTensor {
323 TchOps::cummin(tensor, dim)
324 }
325
326 fn float_cummax(tensor: TchTensor, dim: usize) -> TchTensor {
327 TchOps::cummax(tensor, dim)
328 }
329
330 fn float_prod(tensor: TchTensor) -> TchTensor {
331 TchOps::prod(tensor)
332 }
333
334 fn float_prod_dim(tensor: TchTensor, dim: usize) -> TchTensor {
335 TchOps::prod_dim(tensor, dim)
336 }
337
338 fn float_argmax(tensor: TchTensor, dim: usize) -> TchTensor {
339 TchOps::argmax(tensor, dim)
340 }
341
342 fn float_argmin(tensor: TchTensor, dim: usize) -> TchTensor {
343 TchOps::argmin(tensor, dim)
344 }
345
346 fn float_max_dim(tensor: TchTensor, dim: usize) -> TchTensor {
347 TchOps::max_dim(tensor, dim)
348 }
349
350 fn float_max_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) {
351 TchOps::max_dim_with_indices(tensor, dim)
352 }
353
354 fn float_min_dim(tensor: TchTensor, dim: usize) -> TchTensor {
355 TchOps::min_dim(tensor, dim)
356 }
357
358 fn float_min_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) {
359 TchOps::min_dim_with_indices(tensor, dim)
360 }
361
362 fn float_exp(tensor: TchTensor) -> TchTensor {
363 tensor.unary_ops(|mut tensor| tensor.exp_(), |tensor| tensor.exp())
364 }
365
366 fn float_log(tensor: TchTensor) -> TchTensor {
367 tensor.unary_ops(|mut tensor| tensor.log_(), |tensor| tensor.log())
368 }
369
370 fn float_log1p(tensor: TchTensor) -> TchTensor {
371 tensor.unary_ops(|mut tensor| tensor.log1p_(), |tensor| tensor.log1p())
372 }
373
374 fn float_powf_scalar_impl(tensor: TchTensor, value: f32) -> TchTensor {
375 tensor.unary_ops(
376 |mut tensor| tensor.f_pow_(value as f64).unwrap(),
377 |tensor| tensor.pow_tensor_scalar(value as f64),
378 )
379 }
380
381 fn float_sqrt(tensor: TchTensor) -> TchTensor {
382 tensor.unary_ops(|mut tensor| tensor.sqrt_(), |tensor| tensor.sqrt())
383 }
384
385 fn float_abs(tensor: TchTensor) -> TchTensor {
386 tensor.unary_ops(|mut tensor| tensor.abs_(), |tensor| tensor.abs())
387 }
388
389 fn float_cos(tensor: TchTensor) -> TchTensor {
390 tensor.unary_ops(|mut tensor| tensor.cos_(), |tensor| tensor.cos())
391 }
392
393 fn float_cosh(tensor: TchTensor) -> TchTensor {
394 tensor.unary_ops(|mut tensor| tensor.cosh_(), |tensor| tensor.cosh())
395 }
396
397 fn float_sin(tensor: TchTensor) -> TchTensor {
398 tensor.unary_ops(|mut tensor| tensor.sin_(), |tensor| tensor.sin())
399 }
400
401 fn float_sinh(tensor: TchTensor) -> TchTensor {
402 tensor.unary_ops(|mut tensor| tensor.sinh_(), |tensor| tensor.sinh())
403 }
404
405 fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
406 tensor.unary_ops(|mut tensor| tensor.tan_(), |tensor| tensor.tan())
407 }
408
409 fn float_tanh(tensor: TchTensor) -> TchTensor {
410 tensor.unary_ops(|mut tensor| tensor.tanh_(), |tensor| tensor.tanh())
411 }
412
413 fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
414 tensor.unary_ops(|mut tensor| tensor.acos_(), |tensor| tensor.acos())
415 }
416
417 fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
418 tensor.unary_ops(|mut tensor| tensor.acosh_(), |tensor| tensor.acosh())
419 }
420
421 fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
422 tensor.unary_ops(|mut tensor| tensor.asin_(), |tensor| tensor.asin())
423 }
424
425 fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
426 tensor.unary_ops(|mut tensor| tensor.asinh_(), |tensor| tensor.asinh())
427 }
428
429 fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
430 tensor.unary_ops(|mut tensor| tensor.atan_(), |tensor| tensor.atan())
431 }
432
433 fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
434 tensor.unary_ops(|mut tensor| tensor.atanh_(), |tensor| tensor.atanh())
435 }
436
437 fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
438 TchOps::atan2(lhs, rhs)
439 }
440
441 fn float_round(tensor: TchTensor) -> TchTensor {
442 tensor.unary_ops(|mut tensor| tensor.round_(), |tensor| tensor.round())
443 }
444
445 fn float_floor(tensor: TchTensor) -> TchTensor {
446 tensor.unary_ops(|mut tensor| tensor.floor_(), |tensor| tensor.floor())
447 }
448
449 fn float_ceil(tensor: TchTensor) -> TchTensor {
450 tensor.unary_ops(|mut tensor| tensor.ceil_(), |tensor| tensor.ceil())
451 }
452
453 fn float_trunc(tensor: TchTensor) -> TchTensor {
454 tensor.unary_ops(|mut tensor| tensor.trunc_(), |tensor| tensor.trunc())
455 }
456
457 fn float_erf(tensor: TchTensor) -> TchTensor {
458 tensor.unary_ops(|mut tensor| tensor.erf_(), |tensor| tensor.erf())
459 }
460
461 fn float_cat(tensors: Vec<TchTensor>, dim: usize) -> TchTensor {
462 TchOps::cat(tensors, dim)
463 }
464
465 fn float_clamp_min(tensor: TchTensor, min: E) -> TchTensor {
466 TchOps::clamp_min(tensor, min.elem::<f64>())
467 }
468
469 fn float_clamp_max(tensor: TchTensor, max: <LibTorch<E> as Backend>::FloatElem) -> TchTensor {
470 TchOps::clamp_max(tensor, max.elem::<f64>())
471 }
472
473 fn float_clamp(
474 tensor: TchTensor,
475 min: <LibTorch<E> as Backend>::FloatElem,
476 max: <LibTorch<E> as Backend>::FloatElem,
477 ) -> TchTensor {
478 TchOps::clamp(tensor, min.elem::<f64>(), max.elem::<f64>())
479 }
480
481 fn float_into_int(tensor: TchTensor) -> TchTensor {
482 let tensor = tensor.tensor.to_kind(tch::Kind::Int64);
483 TchTensor::new(tensor)
484 }
485
486 fn float_powf(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
487 TchOps::powf(lhs, rhs)
488 }
489
490 fn float_permute(tensor: TchTensor, axes: &[usize]) -> TchTensor {
491 TchOps::permute(tensor, axes)
492 }
493
494 fn float_flip(tensor: TchTensor, axes: &[usize]) -> TchTensor {
495 TchOps::flip(tensor, axes)
496 }
497
498 fn float_sign(tensor: TchTensor) -> TchTensor {
499 TchOps::sign(tensor)
500 }
501
502 fn float_expand(tensor: TchTensor, shape: Shape) -> TchTensor {
503 TchOps::expand(tensor, shape)
504 }
505
506 fn float_sort(tensor: TchTensor, dim: usize, descending: bool) -> TchTensor {
507 TchOps::sort(tensor, dim, descending)
508 }
509
510 fn float_sort_with_indices(
511 tensor: TchTensor,
512 dim: usize,
513 descending: bool,
514 ) -> (TchTensor, TchTensor) {
515 TchOps::sort_with_indices(tensor, dim, descending)
516 }
517
518 fn float_argsort(tensor: TchTensor, dim: usize, descending: bool) -> IntTensor<Self> {
519 TchOps::argsort(tensor, dim, descending)
520 }
521
522 fn float_cast(tensor: TchTensor, dtype: FloatDType) -> TchTensor {
523 let kind = dtype.into_kind();
528
529 if tensor.tensor.kind() == kind {
530 tensor
531 } else {
532 TchTensor::new(tensor.tensor.to_kind(kind))
533 }
534 }
535
536 fn float_unfold(
537 tensor: FloatTensor<Self>,
538 dim: usize,
539 size: usize,
540 step: usize,
541 ) -> FloatTensor<Self> {
542 TchOps::unfold(tensor, dim, size, step)
543 }
544
545 fn float_is_nan(tensor: FloatTensor<Self>) -> BoolTensor<Self> {
546 TchTensor::new(tensor.tensor.isnan())
547 }
548
549 fn float_is_inf(tensor: FloatTensor<Self>) -> BoolTensor<Self> {
550 TchTensor::new(tensor.tensor.isinf())
551 }
552}