1use super::TchOps;
2use crate::{element::TchElement, LibTorch, LibTorchDevice, QuantElement, TchShape, TchTensor};
3use burn_tensor::{
4 backend::Backend,
5 ops::{FloatTensorOps, IntTensor},
6 Distribution, ElementConversion, FloatDType, Shape, TensorData, TensorMetadata,
7};
8use half::{bf16, f16};
9use std::ops::Range;
10
11impl<E: TchElement, Q: QuantElement> FloatTensorOps<Self> for LibTorch<E, Q> {
12 fn float_from_data(data: TensorData, device: &LibTorchDevice) -> TchTensor {
13 TchTensor::from_data::<E>(data, (*device).into())
14 }
15
16 fn float_random(
17 shape: Shape,
18 distribution: Distribution,
19 device: &LibTorchDevice,
20 ) -> TchTensor {
21 match distribution {
22 Distribution::Default => {
23 let mut tensor = TchTensor::empty::<E>(shape, *device);
24 tensor
25 .mut_ops(|tensor| tensor.rand_like_out(tensor))
26 .unwrap()
27 }
28 Distribution::Bernoulli(prob) => {
29 let mut tensor = TchTensor::empty::<E>(shape, *device);
30 tensor
31 .mut_ops(|tensor| tensor.f_bernoulli_float_(prob).unwrap())
32 .unwrap()
33 }
34 Distribution::Uniform(from, to) => {
35 let mut tensor = TchTensor::empty::<E>(shape, *device);
36 tensor.mut_ops(|tensor| tensor.uniform_(from, to)).unwrap()
37 }
38 Distribution::Normal(mean, std) => {
39 let mut tensor = TchTensor::empty::<E>(shape, *device);
40 tensor.mut_ops(|tensor| tensor.normal_(mean, std)).unwrap()
41 }
42 }
43 }
44
45 fn float_repeat_dim(tensor: TchTensor, dim: usize, times: usize) -> TchTensor {
46 TchOps::repeat_dim(tensor, dim, times)
47 }
48
49 fn float_zeros(shape: Shape, device: &LibTorchDevice) -> TchTensor {
50 let shape = TchShape::from(shape);
51 let device: tch::Device = (*device).into();
52
53 TchTensor::new(tch::Tensor::zeros(shape.dims, (E::KIND, device)))
54 }
55
56 fn float_ones(shape: Shape, device: &LibTorchDevice) -> TchTensor {
57 let shape = TchShape::from(shape);
58 let device: tch::Device = (*device).into();
59
60 TchTensor::new(tch::Tensor::ones(shape.dims, (E::KIND, device)))
61 }
62
63 async fn float_into_data(tensor: TchTensor) -> TensorData {
64 let shape = tensor.shape();
65 let tensor = Self::float_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
66 match tensor.tensor.kind() {
67 tch::Kind::Half => {
68 let values: Vec<f16> = tensor.tensor.try_into().unwrap();
69 TensorData::new(values, shape)
70 }
71 tch::Kind::Float => {
72 let values: Vec<f32> = tensor.tensor.try_into().unwrap();
73 TensorData::new(values, shape)
74 }
75 tch::Kind::Double => {
76 let values: Vec<f64> = tensor.tensor.try_into().unwrap();
77 TensorData::new(values, shape)
78 }
79 tch::Kind::BFloat16 => {
80 let values: Vec<bf16> = tensor.tensor.try_into().unwrap();
81 TensorData::new(values, shape)
82 }
83 _ => panic!("Not a valid float kind"),
84 }
85 }
86
87 fn float_device(tensor: &TchTensor) -> LibTorchDevice {
88 tensor.tensor.device().into()
89 }
90
91 fn float_to_device(tensor: TchTensor, device: &LibTorchDevice) -> TchTensor {
92 TchOps::to_device(tensor, device)
93 }
94
95 fn float_empty(shape: Shape, device: &<LibTorch<E> as Backend>::Device) -> TchTensor {
96 let tensor = tch::Tensor::empty(TchShape::from(shape).dims, (E::KIND, (*device).into()));
97
98 TchTensor::new(tensor)
99 }
100
101 fn float_add(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
102 TchOps::add(lhs, rhs)
103 }
104
105 fn float_add_scalar(lhs: TchTensor, rhs: E) -> TchTensor {
106 let rhs: f64 = rhs.elem();
107
108 lhs.unary_ops(
109 |mut tensor| tensor.f_add_scalar_(rhs).unwrap(),
110 |tensor| tensor.f_add_scalar(rhs).unwrap(),
111 )
112 }
113
114 fn float_sub(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
115 TchOps::sub(lhs, rhs)
116 }
117
118 fn float_sub_scalar(lhs: TchTensor, rhs: E) -> TchTensor {
119 let rhs: f64 = rhs.elem();
120
121 lhs.unary_ops(
122 |mut tensor| tensor.f_sub_scalar_(rhs).unwrap(),
123 |tensor| tensor.f_sub_scalar(rhs).unwrap(),
124 )
125 }
126
127 fn float_mul(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
128 TchOps::mul(lhs, rhs)
129 }
130
131 fn float_mul_scalar(lhs: TchTensor, rhs: E) -> TchTensor {
132 let rhs: f64 = rhs.elem();
133
134 lhs.unary_ops(
135 |mut tensor| tensor.f_mul_scalar_(rhs).unwrap(),
136 |tensor| tensor.f_mul_scalar(rhs).unwrap(),
137 )
138 }
139
140 fn float_div(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
141 TchOps::div(lhs, rhs)
142 }
143
144 fn float_div_scalar(lhs: TchTensor, rhs: E) -> TchTensor {
145 let rhs: f64 = rhs.elem();
146
147 lhs.unary_ops(
148 |mut tensor| tensor.f_div_scalar_(rhs).unwrap(),
149 |tensor| tensor.f_div_scalar(rhs).unwrap(),
150 )
151 }
152
153 fn float_remainder(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
154 TchOps::remainder(lhs, rhs)
155 }
156
157 fn float_remainder_scalar(lhs: TchTensor, rhs: E) -> TchTensor {
158 let rhs: f64 = rhs.elem();
159
160 lhs.unary_ops(
161 |tensor| tensor.f_remainder(rhs).unwrap(),
162 |tensor| tensor.f_remainder(rhs).unwrap(),
163 )
164 }
165
166 fn float_matmul(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
167 let tensor = lhs.tensor.matmul(&rhs.tensor);
168 TchTensor::new(tensor)
169 }
170
171 fn float_neg(tensor: TchTensor) -> TchTensor {
172 Self::float_mul_scalar(tensor, (-1f32).elem::<E>())
173 }
174
175 fn float_recip(tensor: TchTensor) -> TchTensor {
176 TchTensor::new(tensor.tensor.reciprocal())
177 }
178
179 fn float_swap_dims(tensor: TchTensor, dim1: usize, dim2: usize) -> TchTensor {
180 TchOps::swap_dims(tensor, dim1, dim2)
181 }
182
183 fn float_reshape(tensor: TchTensor, shape: Shape) -> TchTensor {
184 TchOps::reshape(tensor, shape)
185 }
186
187 fn float_gather(dim: usize, tensor: TchTensor, indices: TchTensor) -> TchTensor {
188 TchOps::gather(dim, tensor, indices)
189 }
190
191 fn float_scatter(
192 dim: usize,
193 tensor: TchTensor,
194 indices: TchTensor,
195 value: TchTensor,
196 ) -> TchTensor {
197 TchOps::scatter(dim, tensor, indices, value)
198 }
199
200 fn float_select(tensor: TchTensor, dim: usize, indices: TchTensor) -> TchTensor {
201 TchOps::index_select_dim(tensor, dim, indices)
202 }
203
204 fn float_select_assign(
205 tensor: TchTensor,
206 dim: usize,
207 indices: TchTensor,
208 value: TchTensor,
209 ) -> TchTensor {
210 TchOps::select_assign(tensor, dim, indices, value)
211 }
212
213 fn float_slice(tensor: TchTensor, ranges: &[Range<usize>]) -> TchTensor {
214 TchOps::slice(tensor, ranges)
215 }
216
217 fn float_slice_assign(
218 tensor: TchTensor,
219 ranges: &[Range<usize>],
220 value: TchTensor,
221 ) -> TchTensor {
222 TchOps::slice_assign(tensor, ranges, value)
223 }
224
225 fn float_mask_where(tensor: TchTensor, mask: TchTensor, value: TchTensor) -> TchTensor {
226 let output = value.tensor.where_self(&mask.tensor, &tensor.tensor);
227
228 TchTensor::new(output)
229 }
230
231 fn float_mask_fill(tensor: TchTensor, mask: TchTensor, value: E) -> TchTensor {
232 let value: f64 = value.elem();
233
234 tensor.unary_ops(
235 |mut tensor| tensor.f_masked_fill_(&mask.tensor, value).unwrap(),
236 |tensor| tensor.f_masked_fill(&mask.tensor, value).unwrap(),
237 )
238 }
239
240 fn float_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
241 TchOps::equal(lhs, rhs)
242 }
243
244 fn float_equal_elem(lhs: TchTensor, rhs: E) -> TchTensor {
245 TchOps::equal_elem(lhs, rhs.elem::<f64>())
246 }
247
248 fn float_greater(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
249 TchOps::greater(lhs, rhs)
250 }
251
252 fn float_greater_elem(lhs: TchTensor, rhs: E) -> TchTensor {
253 TchOps::greater_elem(lhs, rhs.elem::<f64>())
254 }
255
256 fn float_greater_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
257 TchOps::greater_equal(lhs, rhs)
258 }
259
260 fn float_greater_equal_elem(lhs: TchTensor, rhs: E) -> TchTensor {
261 TchOps::greater_equal_elem(lhs, rhs.elem::<f64>())
262 }
263
264 fn float_lower(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
265 TchOps::lower(lhs, rhs)
266 }
267
268 fn float_lower_elem(lhs: TchTensor, rhs: E) -> TchTensor {
269 TchOps::lower_elem(lhs, rhs.elem::<f64>())
270 }
271
272 fn float_lower_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
273 TchOps::lower_equal(lhs, rhs)
274 }
275
276 fn float_lower_equal_elem(lhs: TchTensor, rhs: E) -> TchTensor {
277 TchOps::lower_equal_elem(lhs, rhs.elem::<f64>())
278 }
279
280 fn float_mean(tensor: TchTensor) -> TchTensor {
281 TchOps::mean(tensor)
282 }
283
284 fn float_sum(tensor: TchTensor) -> TchTensor {
285 TchOps::sum(tensor)
286 }
287
288 fn float_sum_dim(tensor: TchTensor, dim: usize) -> TchTensor {
289 TchOps::sum_dim(tensor, dim)
290 }
291
292 fn float_mean_dim(tensor: TchTensor, dim: usize) -> TchTensor {
293 TchOps::mean_dim(tensor, dim)
294 }
295
296 fn float_prod(tensor: TchTensor) -> TchTensor {
297 TchOps::prod(tensor)
298 }
299
300 fn float_prod_dim(tensor: TchTensor, dim: usize) -> TchTensor {
301 TchOps::prod_dim(tensor, dim)
302 }
303
304 fn float_argmax(tensor: TchTensor, dim: usize) -> TchTensor {
305 TchOps::argmax(tensor, dim)
306 }
307
308 fn float_argmin(tensor: TchTensor, dim: usize) -> TchTensor {
309 TchOps::argmin(tensor, dim)
310 }
311
312 fn float_max_dim(tensor: TchTensor, dim: usize) -> TchTensor {
313 TchOps::max_dim(tensor, dim)
314 }
315
316 fn float_max_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) {
317 TchOps::max_dim_with_indices(tensor, dim)
318 }
319
320 fn float_min_dim(tensor: TchTensor, dim: usize) -> TchTensor {
321 TchOps::min_dim(tensor, dim)
322 }
323
324 fn float_min_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) {
325 TchOps::min_dim_with_indices(tensor, dim)
326 }
327
328 fn float_exp(tensor: TchTensor) -> TchTensor {
329 tensor.unary_ops(|mut tensor| tensor.exp_(), |tensor| tensor.exp())
330 }
331
332 fn float_log(tensor: TchTensor) -> TchTensor {
333 tensor.unary_ops(|mut tensor| tensor.log_(), |tensor| tensor.log())
334 }
335
336 fn float_log1p(tensor: TchTensor) -> TchTensor {
337 tensor.unary_ops(|mut tensor| tensor.log1p_(), |tensor| tensor.log1p())
338 }
339
340 fn float_powf_scalar(tensor: TchTensor, value: f32) -> TchTensor {
341 tensor.unary_ops(
342 |mut tensor| tensor.f_pow_(value as f64).unwrap(),
343 |tensor| tensor.pow_tensor_scalar(value as f64),
344 )
345 }
346
347 fn float_sqrt(tensor: TchTensor) -> TchTensor {
348 tensor.unary_ops(|mut tensor| tensor.sqrt_(), |tensor| tensor.sqrt())
349 }
350
351 fn float_abs(tensor: TchTensor) -> TchTensor {
352 tensor.unary_ops(|mut tensor| tensor.abs_(), |tensor| tensor.abs())
353 }
354
355 fn float_cos(tensor: TchTensor) -> TchTensor {
356 tensor.unary_ops(|mut tensor| tensor.cos_(), |tensor| tensor.cos())
357 }
358
359 fn float_sin(tensor: TchTensor) -> TchTensor {
360 tensor.unary_ops(|mut tensor| tensor.sin_(), |tensor| tensor.sin())
361 }
362
363 fn float_tanh(tensor: TchTensor) -> TchTensor {
364 tensor.unary_ops(|mut tensor| tensor.tanh_(), |tensor| tensor.tanh())
365 }
366
367 fn float_round(tensor: TchTensor) -> TchTensor {
368 tensor.unary_ops(|mut tensor| tensor.round_(), |tensor| tensor.round())
369 }
370
371 fn float_floor(tensor: TchTensor) -> TchTensor {
372 tensor.unary_ops(|mut tensor| tensor.floor_(), |tensor| tensor.floor())
373 }
374
375 fn float_ceil(tensor: TchTensor) -> TchTensor {
376 tensor.unary_ops(|mut tensor| tensor.ceil_(), |tensor| tensor.ceil())
377 }
378
379 fn float_erf(tensor: TchTensor) -> TchTensor {
380 tensor.unary_ops(|mut tensor| tensor.erf_(), |tensor| tensor.erf())
381 }
382
383 fn float_cat(tensors: Vec<TchTensor>, dim: usize) -> TchTensor {
384 TchOps::cat(tensors, dim)
385 }
386
387 fn float_clamp_min(tensor: TchTensor, min: E) -> TchTensor {
388 TchOps::clamp_min(tensor, min.elem::<f64>())
389 }
390
391 fn float_clamp_max(tensor: TchTensor, max: <LibTorch<E> as Backend>::FloatElem) -> TchTensor {
392 TchOps::clamp_max(tensor, max.elem::<f64>())
393 }
394
395 fn float_clamp(
396 tensor: TchTensor,
397 min: <LibTorch<E> as Backend>::FloatElem,
398 max: <LibTorch<E> as Backend>::FloatElem,
399 ) -> TchTensor {
400 TchOps::clamp(tensor, min.elem::<f64>(), max.elem::<f64>())
401 }
402
403 fn float_into_int(tensor: TchTensor) -> TchTensor {
404 let tensor = tensor.tensor.to_kind(tch::Kind::Int64);
405 TchTensor::new(tensor)
406 }
407
408 fn float_narrow(tensor: TchTensor, dim: usize, start: usize, length: usize) -> TchTensor {
409 TchOps::narrow(tensor, dim, start, length)
410 }
411
412 fn float_chunk(tensor: TchTensor, chunks: usize, dim: usize) -> Vec<TchTensor> {
413 TchOps::chunk(tensor, chunks, dim)
414 }
415
416 fn float_split(tensor: TchTensor, split_size: usize, dim: usize) -> Vec<TchTensor> {
417 TchOps::split(tensor, split_size, dim)
418 }
419
420 fn float_split_with_sizes(
421 tensor: TchTensor,
422 split_sizes: Vec<usize>,
423 dim: usize,
424 ) -> Vec<TchTensor> {
425 TchOps::split_with_sizes(tensor, split_sizes, dim)
426 }
427
428 fn float_powf(lhs: TchTensor, rhs: TchTensor) -> TchTensor {
429 TchOps::powf(lhs, rhs)
430 }
431
432 fn float_permute(tensor: TchTensor, axes: &[usize]) -> TchTensor {
433 TchOps::permute(tensor, axes)
434 }
435
436 fn float_flip(tensor: TchTensor, axes: &[usize]) -> TchTensor {
437 TchOps::flip(tensor, axes)
438 }
439
440 fn float_sign(tensor: TchTensor) -> TchTensor {
441 TchOps::sign(tensor)
442 }
443
444 fn float_expand(tensor: TchTensor, shape: Shape) -> TchTensor {
445 TchOps::expand(tensor, shape)
446 }
447
448 fn float_sort(tensor: TchTensor, dim: usize, descending: bool) -> TchTensor {
449 TchOps::sort(tensor, dim, descending)
450 }
451
452 fn float_sort_with_indices(
453 tensor: TchTensor,
454 dim: usize,
455 descending: bool,
456 ) -> (TchTensor, TchTensor) {
457 TchOps::sort_with_indices(tensor, dim, descending)
458 }
459
460 fn float_argsort(tensor: TchTensor, dim: usize, descending: bool) -> IntTensor<Self> {
461 TchOps::argsort(tensor, dim, descending)
462 }
463
464 fn float_cast(tensor: TchTensor, dtype: FloatDType) -> TchTensor {
465 let kind = match dtype {
470 FloatDType::F64 => tch::Kind::Double,
471 FloatDType::F32 => tch::Kind::Float,
472 FloatDType::F16 => tch::Kind::Half,
473 FloatDType::BF16 => tch::Kind::BFloat16,
474 };
475
476 if tensor.tensor.kind() == kind {
477 tensor
478 } else {
479 TchTensor::new(tensor.tensor.to_kind(kind))
480 }
481 }
482}