use super::TchOps;
use crate::{IntoKind, LibTorch, LibTorchDevice, TchShape, TchTensor, element::TchElement};
use burn_tensor::ops::{BoolTensor, FloatTensor};
use burn_tensor::{
    DType, Distribution, ElementConversion, FloatDType, Shape, TensorData, TensorMetadata,
    backend::Backend,
    ops::{FloatTensorOps, IntTensor},
};
use half::{bf16, f16};

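// `FloatTensorOps` implementation for the LibTorch backend. Most operations delegate to the
// shared `TchOps` helpers or map directly onto the corresponding `tch` kernels.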
impl<E: TchElement> FloatTensorOps<Self> for LibTorch<E> {
    fn float_from_data(data: TensorData, device: &LibTorchDevice) -> TchTensor {
        match data.dtype {
            DType::F64 => TchTensor::from_data::<f64>(data, (*device).into()),
            DType::F32 => TchTensor::from_data::<f32>(data, (*device).into()),
            DType::F16 => TchTensor::from_data::<f16>(data, (*device).into()),
            DType::BF16 => TchTensor::from_data::<bf16>(data, (*device).into()),
            _ => unimplemented!("Unsupported dtype for `float_from_data`"),
        }
    }

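    // Random values are produced by allocating an uninitialized tensor and filling it in place
    // with the libtorch kernel that matches the requested distribution.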
    fn float_random(
        shape: Shape,
        distribution: Distribution,
        device: &LibTorchDevice,
    ) -> TchTensor {
        match distribution {
            Distribution::Default => {
                let mut tensor = TchTensor::empty::<E>(shape, *device);
                tensor
                    .mut_ops(|tensor| tensor.rand_like_out(tensor))
                    .unwrap()
            }
            Distribution::Bernoulli(prob) => {
                let mut tensor = TchTensor::empty::<E>(shape, *device);
                tensor
                    .mut_ops(|tensor| tensor.f_bernoulli_float_(prob).unwrap())
                    .unwrap()
            }
            Distribution::Uniform(from, to) => {
                let mut tensor = TchTensor::empty::<E>(shape, *device);
                tensor.mut_ops(|tensor| tensor.uniform_(from, to)).unwrap()
            }
            Distribution::Normal(mean, std) => {
                let mut tensor = TchTensor::empty::<E>(shape, *device);
                tensor.mut_ops(|tensor| tensor.normal_(mean, std)).unwrap()
            }
        }
    }

    fn float_repeat_dim(tensor: TchTensor, dim: usize, times: usize) -> TchTensor {
        TchOps::repeat_dim(tensor, dim, times)
    }

    fn float_zeros(shape: Shape, device: &LibTorchDevice, dtype: FloatDType) -> TchTensor {
        let shape = TchShape::from(shape);
        let device: tch::Device = (*device).into();

        TchTensor::new(tch::Tensor::zeros(shape.dims, (dtype.into_kind(), device)))
    }

    fn float_ones(shape: Shape, device: &LibTorchDevice, dtype: FloatDType) -> TchTensor {
        let shape = TchShape::from(shape);
        let device: tch::Device = (*device).into();

        TchTensor::new(tch::Tensor::ones(shape.dims, (dtype.into_kind(), device)))
    }

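    // Flatten to rank 1 so the `Vec<_>` conversion reads a single contiguous dimension; the
    // original shape is preserved in the returned `TensorData`.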
    async fn float_into_data(tensor: TchTensor) -> TensorData {
        let shape = tensor.shape();
        let tensor = Self::float_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
        match tensor.tensor.kind() {
            tch::Kind::Half => {
                let values: Vec<f16> = tensor.tensor.try_into().unwrap();
                TensorData::new(values, shape)
            }
            tch::Kind::Float => {
                let values: Vec<f32> = tensor.tensor.try_into().unwrap();
                TensorData::new(values, shape)
            }
            tch::Kind::Double => {
                let values: Vec<f64> = tensor.tensor.try_into().unwrap();
                TensorData::new(values, shape)
            }
            tch::Kind::BFloat16 => {
                let values: Vec<bf16> = tensor.tensor.try_into().unwrap();
                TensorData::new(values, shape)
            }
            _ => panic!("Not a valid float kind"),
        }
    }

    fn float_device(tensor: &TchTensor) -> LibTorchDevice {
        tensor.tensor.device().into()
    }

    fn float_to_device(tensor: TchTensor, device: &LibTorchDevice) -> TchTensor {
        TchOps::to_device(tensor, device)
    }

    fn float_empty(
        shape: Shape,
        device: &<LibTorch<E> as Backend>::Device,
        dtype: FloatDType,
    ) -> TchTensor {
        let tensor = tch::Tensor::empty(
            TchShape::from(shape).dims,
            (dtype.into_kind(), (*device).into()),
        );

        TchTensor::new(tensor)
    }

    fn float_add(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
        TchOps::add(lhs, rhs)
    }

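    // Scalar arithmetic goes through `unary_ops`: the first closure mutates in place when the
    // tensor's storage can be reused, the second is the copying fallback. Scalars are widened
    // to `f64`, the scalar type `tch` expects.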
    fn float_add_scalar(lhs: TchTensor, rhs: E) -> TchTensor {
        let rhs: f64 = rhs.elem();

        lhs.unary_ops(
            |mut tensor| tensor.f_add_scalar_(rhs).unwrap(),
            |tensor| tensor.f_add_scalar(rhs).unwrap(),
        )
    }

    fn float_sub(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
        TchOps::sub(lhs, rhs)
    }

    fn float_sub_scalar(lhs: TchTensor, rhs: E) -> TchTensor {
        let rhs: f64 = rhs.elem();

        lhs.unary_ops(
            |mut tensor| tensor.f_sub_scalar_(rhs).unwrap(),
            |tensor| tensor.f_sub_scalar(rhs).unwrap(),
        )
    }

    fn float_mul(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
        TchOps::mul(lhs, rhs)
    }

    fn float_mul_scalar(lhs: TchTensor, rhs: E) -> TchTensor {
        let rhs: f64 = rhs.elem();

        lhs.unary_ops(
            |mut tensor| tensor.f_mul_scalar_(rhs).unwrap(),
            |tensor| tensor.f_mul_scalar(rhs).unwrap(),
        )
    }

    fn float_div(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
        TchOps::div(lhs, rhs)
    }

    fn float_div_scalar(lhs: TchTensor, rhs: E) -> TchTensor {
        let rhs: f64 = rhs.elem();

        lhs.unary_ops(
            |mut tensor| tensor.f_div_scalar_(rhs).unwrap(),
            |tensor| tensor.f_div_scalar(rhs).unwrap(),
        )
    }

    fn float_remainder(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
        TchOps::remainder(lhs, rhs)
    }

    fn float_remainder_scalar(lhs: TchTensor, rhs: E) -> TchTensor {
        let rhs: f64 = rhs.elem();

        lhs.unary_ops(
            |tensor| tensor.f_remainder(rhs).unwrap(),
            |tensor| tensor.f_remainder(rhs).unwrap(),
        )
    }

    fn float_matmul(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
        let tensor = lhs.tensor.matmul(&rhs.tensor);
        TchTensor::new(tensor)
    }

    fn float_cross(lhs: TchTensor, rhs: TchTensor, dim: usize) -> TchTensor {
        let tensor = lhs.tensor.cross(&rhs.tensor, dim as i64);
        TchTensor::new(tensor)
    }

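    // Negation reuses the scalar multiplication path (multiply by -1).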
    fn float_neg(tensor: TchTensor) -> TchTensor {
        Self::float_mul_scalar(tensor, (-1f32).elem::<E>())
    }

    fn float_recip(tensor: TchTensor) -> TchTensor {
        TchTensor::new(tensor.tensor.reciprocal())
    }

    fn float_swap_dims(tensor: TchTensor, dim1: usize, dim2: usize) -> TchTensor {
        TchOps::swap_dims(tensor, dim1, dim2)
    }

    fn float_reshape(tensor: TchTensor, shape: Shape) -> TchTensor {
        TchOps::reshape(tensor, shape)
    }

    fn float_gather(dim: usize, tensor: TchTensor, indices: TchTensor) -> TchTensor {
        TchOps::gather(dim, tensor, indices)
    }

    fn float_scatter(
        dim: usize,
        tensor: TchTensor,
        indices: TchTensor,
        value: TchTensor,
    ) -> TchTensor {
        TchOps::scatter(dim, tensor, indices, value)
    }

    fn float_select(tensor: TchTensor, dim: usize, indices: TchTensor) -> TchTensor {
        TchOps::index_select_dim(tensor, dim, indices)
    }

    fn float_select_assign(
        tensor: TchTensor,
        dim: usize,
        indices: TchTensor,
        value: TchTensor,
    ) -> TchTensor {
        TchOps::select_assign(tensor, dim, indices, value)
    }

    fn float_slice(tensor: TchTensor, slices: &[burn_tensor::Slice]) -> TchTensor {
        TchOps::slice_with_steps(tensor, slices)
    }

    fn float_slice_assign(
        tensor: TchTensor,
        slices: &[burn_tensor::Slice],
        value: TchTensor,
    ) -> TchTensor {
        TchOps::slice_assign(tensor, slices, value)
    }

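    // `where_self` is called on `value`, so positions where `mask` is true take elements from
    // `value` and the remaining positions keep the elements of `tensor`.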
    fn float_mask_where(tensor: TchTensor, mask: TchTensor, value: TchTensor) -> TchTensor {
        let output = value.tensor.where_self(&mask.tensor, &tensor.tensor);

        TchTensor::new(output)
    }

    fn float_mask_fill(tensor: TchTensor, mask: TchTensor, value: E) -> TchTensor {
        let value: f64 = value.elem();

        tensor.unary_ops(
            |mut tensor| tensor.f_masked_fill_(&mask.tensor, value).unwrap(),
            |tensor| tensor.f_masked_fill(&mask.tensor, value).unwrap(),
        )
    }

    fn float_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
        TchOps::equal(lhs, rhs)
    }

    fn float_equal_elem(lhs: TchTensor, rhs: E) -> TchTensor {
        TchOps::equal_elem(lhs, rhs.elem::<f64>())
    }

    fn float_greater(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
        TchOps::greater(lhs, rhs)
    }

    fn float_greater_elem(lhs: TchTensor, rhs: E) -> TchTensor {
        TchOps::greater_elem(lhs, rhs.elem::<f64>())
    }

    fn float_greater_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
        TchOps::greater_equal(lhs, rhs)
    }

    fn float_greater_equal_elem(lhs: TchTensor, rhs: E) -> TchTensor {
        TchOps::greater_equal_elem(lhs, rhs.elem::<f64>())
    }

    fn float_lower(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
        TchOps::lower(lhs, rhs)
    }

    fn float_lower_elem(lhs: TchTensor, rhs: E) -> TchTensor {
        TchOps::lower_elem(lhs, rhs.elem::<f64>())
    }

    fn float_lower_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
        TchOps::lower_equal(lhs, rhs)
    }

    fn float_lower_equal_elem(lhs: TchTensor, rhs: E) -> TchTensor {
        TchOps::lower_equal_elem(lhs, rhs.elem::<f64>())
    }

    fn float_mean(tensor: TchTensor) -> TchTensor {
        TchOps::mean(tensor)
    }

    fn float_sum(tensor: TchTensor) -> TchTensor {
        TchOps::sum(tensor)
    }

    fn float_sum_dim(tensor: TchTensor, dim: usize) -> TchTensor {
        TchOps::sum_dim(tensor, dim)
    }

    fn float_mean_dim(tensor: TchTensor, dim: usize) -> TchTensor {
        TchOps::mean_dim(tensor, dim)
    }

    fn float_cumsum(tensor: TchTensor, dim: usize) -> TchTensor {
        TchOps::cumsum(tensor, dim)
    }

    fn float_cumprod(tensor: TchTensor, dim: usize) -> TchTensor {
        TchOps::cumprod(tensor, dim)
    }

    fn float_cummin(tensor: TchTensor, dim: usize) -> TchTensor {
        TchOps::cummin(tensor, dim)
    }

    fn float_cummax(tensor: TchTensor, dim: usize) -> TchTensor {
        TchOps::cummax(tensor, dim)
    }

    fn float_prod(tensor: TchTensor) -> TchTensor {
        TchOps::prod(tensor)
    }

    fn float_prod_dim(tensor: TchTensor, dim: usize) -> TchTensor {
        TchOps::prod_dim(tensor, dim)
    }

    fn float_argmax(tensor: TchTensor, dim: usize) -> TchTensor {
        TchOps::argmax(tensor, dim)
    }

    fn float_argmin(tensor: TchTensor, dim: usize) -> TchTensor {
        TchOps::argmin(tensor, dim)
    }

    fn float_max_dim(tensor: TchTensor, dim: usize) -> TchTensor {
        TchOps::max_dim(tensor, dim)
    }

    fn float_max_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) {
        TchOps::max_dim_with_indices(tensor, dim)
    }

    fn float_min_dim(tensor: TchTensor, dim: usize) -> TchTensor {
        TchOps::min_dim(tensor, dim)
    }

    fn float_min_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) {
        TchOps::min_dim_with_indices(tensor, dim)
    }

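    // The element-wise math ops below all follow the same `unary_ops` pattern: an in-place
    // variant when the tensor can be mutated, and a copying fallback otherwise.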
    fn float_exp(tensor: TchTensor) -> TchTensor {
        tensor.unary_ops(|mut tensor| tensor.exp_(), |tensor| tensor.exp())
    }

    fn float_log(tensor: TchTensor) -> TchTensor {
        tensor.unary_ops(|mut tensor| tensor.log_(), |tensor| tensor.log())
    }

    fn float_log1p(tensor: TchTensor) -> TchTensor {
        tensor.unary_ops(|mut tensor| tensor.log1p_(), |tensor| tensor.log1p())
    }

    fn float_powf_scalar_impl(tensor: TchTensor, value: f32) -> TchTensor {
        tensor.unary_ops(
            |mut tensor| tensor.f_pow_(value as f64).unwrap(),
            |tensor| tensor.pow_tensor_scalar(value as f64),
        )
    }

    fn float_sqrt(tensor: TchTensor) -> TchTensor {
        tensor.unary_ops(|mut tensor| tensor.sqrt_(), |tensor| tensor.sqrt())
    }

    fn float_abs(tensor: TchTensor) -> TchTensor {
        tensor.unary_ops(|mut tensor| tensor.abs_(), |tensor| tensor.abs())
    }

    fn float_cos(tensor: TchTensor) -> TchTensor {
        tensor.unary_ops(|mut tensor| tensor.cos_(), |tensor| tensor.cos())
    }

    fn float_sin(tensor: TchTensor) -> TchTensor {
        tensor.unary_ops(|mut tensor| tensor.sin_(), |tensor| tensor.sin())
    }

    fn float_tanh(tensor: TchTensor) -> TchTensor {
        tensor.unary_ops(|mut tensor| tensor.tanh_(), |tensor| tensor.tanh())
    }

    fn float_round(tensor: TchTensor) -> TchTensor {
        tensor.unary_ops(|mut tensor| tensor.round_(), |tensor| tensor.round())
    }

    fn float_floor(tensor: TchTensor) -> TchTensor {
        tensor.unary_ops(|mut tensor| tensor.floor_(), |tensor| tensor.floor())
    }

    fn float_ceil(tensor: TchTensor) -> TchTensor {
        tensor.unary_ops(|mut tensor| tensor.ceil_(), |tensor| tensor.ceil())
    }

    fn float_trunc(tensor: TchTensor) -> TchTensor {
        tensor.unary_ops(|mut tensor| tensor.trunc_(), |tensor| tensor.trunc())
    }

    fn float_erf(tensor: TchTensor) -> TchTensor {
        tensor.unary_ops(|mut tensor| tensor.erf_(), |tensor| tensor.erf())
    }

    fn float_cat(tensors: Vec<TchTensor>, dim: usize) -> TchTensor {
        TchOps::cat(tensors, dim)
    }

    fn float_clamp_min(tensor: TchTensor, min: E) -> TchTensor {
        TchOps::clamp_min(tensor, min.elem::<f64>())
    }

    fn float_clamp_max(tensor: TchTensor, max: <LibTorch<E> as Backend>::FloatElem) -> TchTensor {
        TchOps::clamp_max(tensor, max.elem::<f64>())
    }

    fn float_clamp(
        tensor: TchTensor,
        min: <LibTorch<E> as Backend>::FloatElem,
        max: <LibTorch<E> as Backend>::FloatElem,
    ) -> TchTensor {
        TchOps::clamp(tensor, min.elem::<f64>(), max.elem::<f64>())
    }

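    // Float-to-int conversion targets `Int64`, the kind backing this backend's int tensors.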
    fn float_into_int(tensor: TchTensor) -> TchTensor {
        let tensor = tensor.tensor.to_kind(tch::Kind::Int64);
        TchTensor::new(tensor)
    }

    fn float_powf(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
        TchOps::powf(lhs, rhs)
    }

    fn float_permute(tensor: TchTensor, axes: &[usize]) -> TchTensor {
        TchOps::permute(tensor, axes)
    }

    fn float_flip(tensor: TchTensor, axes: &[usize]) -> TchTensor {
        TchOps::flip(tensor, axes)
    }

    fn float_sign(tensor: TchTensor) -> TchTensor {
        TchOps::sign(tensor)
    }

    fn float_expand(tensor: TchTensor, shape: Shape) -> TchTensor {
        TchOps::expand(tensor, shape)
    }

    fn float_sort(tensor: TchTensor, dim: usize, descending: bool) -> TchTensor {
        TchOps::sort(tensor, dim, descending)
    }

    fn float_sort_with_indices(
        tensor: TchTensor,
        dim: usize,
        descending: bool,
    ) -> (TchTensor, TchTensor) {
        TchOps::sort_with_indices(tensor, dim, descending)
    }

    fn float_argsort(tensor: TchTensor, dim: usize, descending: bool) -> IntTensor<Self> {
        TchOps::argsort(tensor, dim, descending)
    }

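    // `Flex32` has no libtorch equivalent and is mapped to a regular 32-bit float; the cast is
    // skipped when the tensor already has the requested kind.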
    fn float_cast(tensor: TchTensor, dtype: FloatDType) -> TchTensor {
        let kind = match dtype {
            FloatDType::F64 => tch::Kind::Double,
            FloatDType::F32 => tch::Kind::Float,
            FloatDType::Flex32 => tch::Kind::Float,
            FloatDType::F16 => tch::Kind::Half,
            FloatDType::BF16 => tch::Kind::BFloat16,
        };

        if tensor.tensor.kind() == kind {
            tensor
        } else {
            TchTensor::new(tensor.tensor.to_kind(kind))
        }
    }

    fn float_unfold(
        tensor: FloatTensor<Self>,
        dim: usize,
        size: usize,
        step: usize,
    ) -> FloatTensor<Self> {
        TchOps::unfold(tensor, dim, size, step)
    }

    fn float_is_nan(tensor: FloatTensor<Self>) -> BoolTensor<Self> {
        TchTensor::new(tensor.tensor.isnan())
    }

    fn float_is_inf(tensor: FloatTensor<Self>) -> BoolTensor<Self> {
        TchTensor::new(tensor.tensor.isinf())
    }
}