1use super::TchOps;
2use crate::{IntoKind, LibTorch, LibTorchDevice, TchShape, TchTensor, element::TchElement};
3use burn_backend::backend::ExecutionError;
4use burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor};
5use burn_backend::{BoolDType, IntDType, Scalar, bf16, f16};
6use burn_backend::{
7 DType, Distribution, FloatDType, Shape, TensorData, TensorMetadata, ops::FloatTensorOps,
8};
9
10impl<E: TchElement> FloatTensorOps<Self> for LibTorch<E> {
11 fn float_from_data(data: TensorData, device: &LibTorchDevice) -> TchTensor {
12 match data.dtype {
13 DType::F64 => TchTensor::from_data::<f64>(data, (*device).into()),
14 DType::F32 => TchTensor::from_data::<f32>(data, (*device).into()),
15 DType::F16 => TchTensor::from_data::<f16>(data, (*device).into()),
16 DType::BF16 => TchTensor::from_data::<bf16>(data, (*device).into()),
17 _ => unimplemented!("Unsupported dtype for `float_from_data`"),
18 }
19 }
20
21 fn float_random(
22 shape: Shape,
23 distribution: Distribution,
24 device: &LibTorchDevice,
25 dtype: FloatDType,
26 ) -> TchTensor {
27 match distribution {
28 Distribution::Default => {
29 let mut tensor = TchTensor::empty(shape, *device, dtype.into());
30 tensor
31 .mut_ops(|tensor| tensor.rand_like_out(tensor))
32 .unwrap()
33 }
34 Distribution::Bernoulli(prob) => {
35 let mut tensor = TchTensor::empty(shape, *device, dtype.into());
36 tensor
37 .mut_ops(|tensor| tensor.f_bernoulli_float_(prob).unwrap())
38 .unwrap()
39 }
40 Distribution::Uniform(from, to) => {
41 let mut tensor = TchTensor::empty(shape, *device, dtype.into());
42 tensor.mut_ops(|tensor| tensor.uniform_(from, to)).unwrap()
43 }
44 Distribution::Normal(mean, std) => {
45 let mut tensor = TchTensor::empty(shape, *device, dtype.into());
46 tensor.mut_ops(|tensor| tensor.normal_(mean, std)).unwrap()
47 }
48 }
49 }
50
51 fn float_repeat_dim(tensor: TchTensor, dim: usize, times: usize) -> TchTensor {
52 TchOps::repeat_dim(tensor, dim, times)
53 }
54
55 fn float_zeros(shape: Shape, device: &LibTorchDevice, dtype: FloatDType) -> TchTensor {
56 let shape = TchShape::from(shape);
57 let device: tch::Device = (*device).into();
58
59 TchTensor::new(tch::Tensor::zeros(shape.dims, (dtype.into_kind(), device)))
60 }
61
62 fn float_ones(shape: Shape, device: &LibTorchDevice, dtype: FloatDType) -> TchTensor {
63 let shape = TchShape::from(shape);
64 let device: tch::Device = (*device).into();
65
66 TchTensor::new(tch::Tensor::ones(shape.dims, (dtype.into_kind(), device)))
67 }
68
69 async fn float_into_data(tensor: TchTensor) -> Result<TensorData, ExecutionError> {
70 let shape = tensor.shape();
71 let tensor = Self::float_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
72 Ok(match tensor.tensor.kind() {
73 tch::Kind::Half => {
74 let values = Vec::<f16>::try_from(&tensor).unwrap();
75 TensorData::new(values, shape)
76 }
77 tch::Kind::Float => {
78 let values = Vec::<f32>::try_from(&tensor).unwrap();
79 TensorData::new(values, shape)
80 }
81 tch::Kind::Double => {
82 let values = Vec::<f64>::try_from(&tensor).unwrap();
83 TensorData::new(values, shape)
84 }
85 tch::Kind::BFloat16 => {
86 let values = Vec::<bf16>::try_from(&tensor).unwrap();
87 TensorData::new(values, shape)
88 }
89 _ => panic!("Not a valid float kind"),
90 })
91 }
92
93 fn float_device(tensor: &TchTensor) -> LibTorchDevice {
94 tensor.tensor.device().into()
95 }
96
97 fn float_to_device(tensor: TchTensor, device: &LibTorchDevice) -> TchTensor {
98 TchOps::to_device(tensor, device)
99 }
100
101 fn float_empty(shape: Shape, device: &LibTorchDevice, dtype: FloatDType) -> TchTensor {
102 let tensor = tch::Tensor::empty(
103 TchShape::from(shape).dims,
104 (dtype.into_kind(), (*device).into()),
105 );
106
107 TchTensor::new(tensor)
108 }
109
110 fn float_add(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
111 TchOps::add(lhs, rhs)
112 }
113
114 fn float_add_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {
115 let rhs: f64 = rhs.elem();
116
117 lhs.unary_ops(
118 |mut tensor| tensor.f_add_scalar_(rhs).unwrap(),
119 |tensor| tensor.f_add_scalar(rhs).unwrap(),
120 )
121 }
122
123 fn float_sub(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
124 TchOps::sub(lhs, rhs)
125 }
126
127 fn float_sub_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {
128 let rhs: f64 = rhs.elem();
129
130 lhs.unary_ops(
131 |mut tensor| tensor.f_sub_scalar_(rhs).unwrap(),
132 |tensor| tensor.f_sub_scalar(rhs).unwrap(),
133 )
134 }
135
136 fn float_mul(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
137 TchOps::mul(lhs, rhs)
138 }
139
140 fn float_mul_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {
141 let rhs: f64 = rhs.elem();
142
143 lhs.unary_ops(
144 |mut tensor| tensor.f_mul_scalar_(rhs).unwrap(),
145 |tensor| tensor.f_mul_scalar(rhs).unwrap(),
146 )
147 }
148
149 fn float_div(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
150 TchOps::div(lhs, rhs)
151 }
152
153 fn float_div_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {
154 let rhs: f64 = rhs.elem();
155
156 lhs.unary_ops(
157 |mut tensor| tensor.f_div_scalar_(rhs).unwrap(),
158 |tensor| tensor.f_div_scalar(rhs).unwrap(),
159 )
160 }
161
162 fn float_remainder(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
163 TchOps::remainder(lhs, rhs)
164 }
165
166 fn float_remainder_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {
167 let rhs: f64 = rhs.elem();
168
169 lhs.unary_ops(
170 |tensor| tensor.f_remainder(rhs).unwrap(),
171 |tensor| tensor.f_remainder(rhs).unwrap(),
172 )
173 }
174
175 fn float_matmul(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
176 let tensor = lhs.tensor.matmul(&rhs.tensor);
177 TchTensor::new(tensor)
178 }
179
180 fn float_cross(lhs: TchTensor, rhs: TchTensor, dim: usize) -> TchTensor {
181 let tensor = lhs.tensor.cross(&rhs.tensor, dim as i64);
182 TchTensor::new(tensor)
183 }
184
185 fn float_recip(tensor: TchTensor) -> TchTensor {
186 TchTensor::new(tensor.tensor.reciprocal())
187 }
188
189 fn float_swap_dims(tensor: TchTensor, dim1: usize, dim2: usize) -> TchTensor {
190 TchOps::swap_dims(tensor, dim1, dim2)
191 }
192
193 fn float_reshape(tensor: TchTensor, shape: Shape) -> TchTensor {
194 TchOps::reshape(tensor, shape)
195 }
196
197 fn float_gather(dim: usize, tensor: TchTensor, indices: TchTensor) -> TchTensor {
198 TchOps::gather(dim, tensor, indices)
199 }
200
201 fn float_scatter_add(
202 dim: usize,
203 tensor: TchTensor,
204 indices: TchTensor,
205 value: TchTensor,
206 ) -> TchTensor {
207 TchOps::scatter(dim, tensor, indices, value)
208 }
209
210 fn float_scatter_nd(
211 data: TchTensor,
212 indices: TchTensor,
213 values: TchTensor,
214 reduction: burn_backend::tensor::IndexingUpdateOp,
215 ) -> TchTensor {
216 TchOps::scatter_nd(data, indices, values, reduction)
217 }
218
219 fn float_gather_nd(data: TchTensor, indices: TchTensor) -> TchTensor {
220 TchOps::gather_nd(data, indices)
221 }
222
223 fn float_select(tensor: TchTensor, dim: usize, indices: TchTensor) -> TchTensor {
224 TchOps::index_select_dim(tensor, dim, indices)
225 }
226
227 fn float_select_add(
228 tensor: TchTensor,
229 dim: usize,
230 indices: TchTensor,
231 value: TchTensor,
232 ) -> TchTensor {
233 TchOps::select_assign(tensor, dim, indices, value)
234 }
235
236 fn float_slice(tensor: TchTensor, slices: &[burn_backend::Slice]) -> TchTensor {
237 TchOps::slice_with_steps(tensor, slices)
238 }
239
240 fn float_slice_assign(
241 tensor: TchTensor,
242 slices: &[burn_backend::Slice],
243 value: TchTensor,
244 ) -> TchTensor {
245 TchOps::slice_assign(tensor, slices, value)
246 }
247
248 fn float_mask_where(tensor: TchTensor, mask: TchTensor, value: TchTensor) -> TchTensor {
249 let output = value.tensor.where_self(&mask.tensor, &tensor.tensor);
250
251 TchTensor::new(output)
252 }
253
254 fn float_mask_fill(tensor: TchTensor, mask: TchTensor, value: Scalar) -> TchTensor {
255 let value: f64 = value.elem();
256
257 tensor.unary_ops(
258 |mut tensor| tensor.f_masked_fill_(&mask.tensor, value).unwrap(),
259 |tensor| tensor.f_masked_fill(&mask.tensor, value).unwrap(),
260 )
261 }
262
263 fn float_equal(lhs: TchTensor, rhs: TchTensor, _out_dtype: BoolDType) -> TchTensor {
264 TchOps::equal(lhs, rhs)
265 }
266
267 fn float_equal_elem(lhs: TchTensor, rhs: Scalar, _out_dtype: BoolDType) -> TchTensor {
268 TchOps::equal_elem(lhs, rhs.elem::<f64>())
269 }
270
271 fn float_greater(lhs: TchTensor, rhs: TchTensor, _out_dtype: BoolDType) -> TchTensor {
272 TchOps::greater(lhs, rhs)
273 }
274
275 fn float_greater_elem(lhs: TchTensor, rhs: Scalar, _out_dtype: BoolDType) -> TchTensor {
276 TchOps::greater_elem(lhs, rhs.elem::<f64>())
277 }
278
279 fn float_greater_equal(lhs: TchTensor, rhs: TchTensor, _out_dtype: BoolDType) -> TchTensor {
280 TchOps::greater_equal(lhs, rhs)
281 }
282
283 fn float_greater_equal_elem(lhs: TchTensor, rhs: Scalar, _out_dtype: BoolDType) -> TchTensor {
284 TchOps::greater_equal_elem(lhs, rhs.elem::<f64>())
285 }
286
287 fn float_lower(lhs: TchTensor, rhs: TchTensor, _out_dtype: BoolDType) -> TchTensor {
288 TchOps::lower(lhs, rhs)
289 }
290
291 fn float_lower_elem(lhs: TchTensor, rhs: Scalar, _out_dtype: BoolDType) -> TchTensor {
292 TchOps::lower_elem(lhs, rhs.elem::<f64>())
293 }
294
295 fn float_lower_equal(lhs: TchTensor, rhs: TchTensor, _out_dtype: BoolDType) -> TchTensor {
296 TchOps::lower_equal(lhs, rhs)
297 }
298
299 fn float_lower_equal_elem(lhs: TchTensor, rhs: Scalar, _out_dtype: BoolDType) -> TchTensor {
300 TchOps::lower_equal_elem(lhs, rhs.elem::<f64>())
301 }
302
303 fn float_mean(tensor: TchTensor) -> TchTensor {
304 TchOps::mean(tensor)
305 }
306
307 fn float_sum(tensor: TchTensor) -> TchTensor {
308 TchOps::sum(tensor)
309 }
310
311 fn float_sum_dim(tensor: TchTensor, dim: usize) -> TchTensor {
312 TchOps::sum_dim(tensor, dim)
313 }
314
315 fn float_mean_dim(tensor: TchTensor, dim: usize) -> TchTensor {
316 TchOps::mean_dim(tensor, dim)
317 }
318
319 fn float_cumsum(tensor: TchTensor, dim: usize) -> TchTensor {
320 TchOps::cumsum(tensor, dim)
321 }
322
323 fn float_cumprod(tensor: TchTensor, dim: usize) -> TchTensor {
324 TchOps::cumprod(tensor, dim)
325 }
326
327 fn float_cummin(tensor: TchTensor, dim: usize) -> TchTensor {
328 TchOps::cummin(tensor, dim)
329 }
330
331 fn float_cummax(tensor: TchTensor, dim: usize) -> TchTensor {
332 TchOps::cummax(tensor, dim)
333 }
334
335 fn float_prod(tensor: TchTensor) -> TchTensor {
336 TchOps::prod(tensor)
337 }
338
339 fn float_prod_dim(tensor: TchTensor, dim: usize) -> TchTensor {
340 TchOps::prod_dim(tensor, dim)
341 }
342
343 fn float_argmax(tensor: TchTensor, dim: usize, _indices_dtype: IntDType) -> TchTensor {
344 TchOps::argmax(tensor, dim)
345 }
346
347 fn float_argtopk(
348 tensor: TchTensor,
349 dim: usize,
350 k: usize,
351 _indices_dtype: IntDType,
352 ) -> TchTensor {
353 TchOps::argtopk(tensor, dim, k)
354 }
355
356 fn float_topk(tensor: TchTensor, dim: usize, k: usize) -> TchTensor {
357 TchOps::topk(tensor, dim, k)
358 }
359
360 fn float_argmin(tensor: TchTensor, dim: usize, _out_dtype: IntDType) -> TchTensor {
361 TchOps::argmin(tensor, dim)
362 }
363
364 fn float_max_dim(tensor: TchTensor, dim: usize) -> TchTensor {
365 TchOps::max_dim(tensor, dim)
366 }
367
368 fn float_max_dim_with_indices(
369 tensor: TchTensor,
370 dim: usize,
371 _indices_dtype: IntDType,
372 ) -> (TchTensor, TchTensor) {
373 TchOps::max_dim_with_indices(tensor, dim)
374 }
375
376 fn float_min_dim(tensor: TchTensor, dim: usize) -> TchTensor {
377 TchOps::min_dim(tensor, dim)
378 }
379
380 fn float_min_dim_with_indices(
381 tensor: TchTensor,
382 dim: usize,
383 _indices_dtype: IntDType,
384 ) -> (TchTensor, TchTensor) {
385 TchOps::min_dim_with_indices(tensor, dim)
386 }
387
388 fn float_exp(tensor: TchTensor) -> TchTensor {
389 tensor.unary_ops(|mut tensor| tensor.exp_(), |tensor| tensor.exp())
390 }
391
392 fn float_log(tensor: TchTensor) -> TchTensor {
393 tensor.unary_ops(|mut tensor| tensor.log_(), |tensor| tensor.log())
394 }
395
396 fn float_log1p(tensor: TchTensor) -> TchTensor {
397 tensor.unary_ops(|mut tensor| tensor.log1p_(), |tensor| tensor.log1p())
398 }
399
400 fn float_powf_scalar_impl(tensor: TchTensor, value: Scalar) -> TchTensor {
401 tensor.unary_ops(
402 |mut tensor| tensor.f_pow_(value.elem::<f64>()).unwrap(),
403 |tensor| tensor.pow_tensor_scalar(value.elem::<f64>()),
404 )
405 }
406
407 fn float_sqrt(tensor: TchTensor) -> TchTensor {
408 tensor.unary_ops(|mut tensor| tensor.sqrt_(), |tensor| tensor.sqrt())
409 }
410
411 fn float_abs(tensor: TchTensor) -> TchTensor {
412 tensor.unary_ops(|mut tensor| tensor.abs_(), |tensor| tensor.abs())
413 }
414
415 fn float_cos(tensor: TchTensor) -> TchTensor {
416 tensor.unary_ops(|mut tensor| tensor.cos_(), |tensor| tensor.cos())
417 }
418
419 fn float_cosh(tensor: TchTensor) -> TchTensor {
420 tensor.unary_ops(|mut tensor| tensor.cosh_(), |tensor| tensor.cosh())
421 }
422
423 fn float_sin(tensor: TchTensor) -> TchTensor {
424 tensor.unary_ops(|mut tensor| tensor.sin_(), |tensor| tensor.sin())
425 }
426
427 fn float_sinh(tensor: TchTensor) -> TchTensor {
428 tensor.unary_ops(|mut tensor| tensor.sinh_(), |tensor| tensor.sinh())
429 }
430
431 fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
432 tensor.unary_ops(|mut tensor| tensor.tan_(), |tensor| tensor.tan())
433 }
434
435 fn float_tanh(tensor: TchTensor) -> TchTensor {
436 tensor.unary_ops(|mut tensor| tensor.tanh_(), |tensor| tensor.tanh())
437 }
438
439 fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
440 tensor.unary_ops(|mut tensor| tensor.acos_(), |tensor| tensor.acos())
441 }
442
443 fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
444 tensor.unary_ops(|mut tensor| tensor.acosh_(), |tensor| tensor.acosh())
445 }
446
447 fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
448 tensor.unary_ops(|mut tensor| tensor.asin_(), |tensor| tensor.asin())
449 }
450
451 fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
452 tensor.unary_ops(|mut tensor| tensor.asinh_(), |tensor| tensor.asinh())
453 }
454
455 fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
456 tensor.unary_ops(|mut tensor| tensor.atan_(), |tensor| tensor.atan())
457 }
458
459 fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
460 tensor.unary_ops(|mut tensor| tensor.atanh_(), |tensor| tensor.atanh())
461 }
462
463 fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
464 TchOps::atan2(lhs, rhs)
465 }
466
467 fn float_round(tensor: TchTensor) -> TchTensor {
468 tensor.unary_ops(|mut tensor| tensor.round_(), |tensor| tensor.round())
469 }
470
471 fn float_floor(tensor: TchTensor) -> TchTensor {
472 tensor.unary_ops(|mut tensor| tensor.floor_(), |tensor| tensor.floor())
473 }
474
475 fn float_ceil(tensor: TchTensor) -> TchTensor {
476 tensor.unary_ops(|mut tensor| tensor.ceil_(), |tensor| tensor.ceil())
477 }
478
479 fn float_trunc(tensor: TchTensor) -> TchTensor {
480 tensor.unary_ops(|mut tensor| tensor.trunc_(), |tensor| tensor.trunc())
481 }
482
483 fn float_erf(tensor: TchTensor) -> TchTensor {
484 tensor.unary_ops(|mut tensor| tensor.erf_(), |tensor| tensor.erf())
485 }
486
487 fn float_cat(tensors: Vec<TchTensor>, dim: usize) -> TchTensor {
488 TchOps::cat(tensors, dim)
489 }
490
491 fn float_clamp_min(tensor: TchTensor, min: Scalar) -> TchTensor {
492 TchOps::clamp_min(tensor, min.elem::<f64>())
493 }
494
495 fn float_clamp_max(tensor: TchTensor, max: Scalar) -> TchTensor {
496 TchOps::clamp_max(tensor, max.elem::<f64>())
497 }
498
499 fn float_clamp(tensor: TchTensor, min: Scalar, max: Scalar) -> TchTensor {
500 TchOps::clamp(tensor, min.elem::<f64>(), max.elem::<f64>())
501 }
502
503 fn float_into_int(tensor: TchTensor, _out_dtype: IntDType) -> TchTensor {
504 let tensor = tensor.tensor.to_kind(tch::Kind::Int64);
505 TchTensor::new(tensor)
506 }
507
508 fn float_powf(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
509 TchOps::pow(lhs, rhs)
510 }
511
512 fn float_permute(tensor: TchTensor, axes: &[usize]) -> TchTensor {
513 TchOps::permute(tensor, axes)
514 }
515
516 fn float_flip(tensor: TchTensor, axes: &[usize]) -> TchTensor {
517 TchOps::flip(tensor, axes)
518 }
519
520 fn float_sign(tensor: TchTensor) -> TchTensor {
521 TchOps::sign(tensor)
522 }
523
524 fn float_expand(tensor: TchTensor, shape: Shape) -> TchTensor {
525 TchOps::expand(tensor, shape)
526 }
527
528 fn float_sort(tensor: TchTensor, dim: usize, descending: bool) -> TchTensor {
529 TchOps::sort(tensor, dim, descending)
530 }
531
532 fn float_sort_with_indices(
533 tensor: TchTensor,
534 dim: usize,
535 descending: bool,
536 _indices_dtype: IntDType,
537 ) -> (TchTensor, TchTensor) {
538 TchOps::sort_with_indices(tensor, dim, descending)
539 }
540
541 fn float_argsort(
542 tensor: TchTensor,
543 dim: usize,
544 descending: bool,
545 _out_dtype: IntDType,
546 ) -> IntTensor<Self> {
547 TchOps::argsort(tensor, dim, descending)
548 }
549
550 fn float_cast(tensor: TchTensor, dtype: FloatDType) -> TchTensor {
551 let kind = dtype.into_kind();
556
557 if tensor.tensor.kind() == kind {
558 tensor
559 } else {
560 TchTensor::new(tensor.tensor.to_kind(kind))
561 }
562 }
563
564 fn float_unfold(
565 tensor: FloatTensor<Self>,
566 dim: usize,
567 size: usize,
568 step: usize,
569 ) -> FloatTensor<Self> {
570 TchOps::unfold(tensor, dim, size, step)
571 }
572
573 fn float_is_nan(tensor: FloatTensor<Self>, _out_dtype: BoolDType) -> BoolTensor<Self> {
574 TchTensor::new(tensor.tensor.isnan())
575 }
576
577 fn float_is_inf(tensor: FloatTensor<Self>, _out_dtype: BoolDType) -> BoolTensor<Self> {
578 TchTensor::new(tensor.tensor.isinf())
579 }
580}