Skip to main content

burn_dispatch/ops/
tensor.rs

1use alloc::vec::Vec;
2use burn_backend::{
3    BoolDType, ExecutionError, FloatDType, IntDType, Scalar, Shape, Slice, TensorData,
4    ops::FloatTensorOps,
5    tensor::{BoolTensor, FloatTensor, IntTensor},
6};
7
8use crate::backends::*;
9use crate::{Dispatch, DispatchDevice};
10
11impl FloatTensorOps<Self> for Dispatch {
12    fn float_from_data(
13        data: burn_backend::TensorData,
14        device: &DispatchDevice,
15    ) -> FloatTensor<Self> {
16        creation_op!(Float, device, |device| B::float_from_data(data, device))
17    }
18
19    fn float_random(
20        shape: Shape,
21        distribution: burn_backend::Distribution,
22        device: &DispatchDevice,
23        dtype: FloatDType,
24    ) -> FloatTensor<Self> {
25        creation_op!(Float, device, |device| {
26            B::float_random(shape, distribution, device, dtype)
27        })
28    }
29
30    async fn float_into_data(tensor: FloatTensor<Self>) -> Result<TensorData, ExecutionError> {
31        unary_float!(tensor, float, |tensor| B::float_into_data(tensor).await)
32    }
33
34    fn float_device(tensor: &FloatTensor<Self>) -> DispatchDevice {
35        tensor.device()
36    }
37
38    fn float_to_device(tensor: FloatTensor<Self>, device: &DispatchDevice) -> FloatTensor<Self> {
39        float_to_device!(
40            Float,
41            float,
42            tensor,
43            device,
44            float_to_device,
45            |inner, device| {
46                let data =
47                    burn_backend::read_sync(B1::float_into_data(inner)).expect("Should read data");
48                B2::float_from_data(data, device)
49            }
50        )
51    }
52
53    fn float_into_int(tensor: FloatTensor<Self>, dtype: burn_backend::IntDType) -> IntTensor<Self> {
54        unary_float!(tensor, float, |tensor| B::float_into_int(tensor, dtype) => Int)
55    }
56
57    fn float_empty(shape: Shape, device: &DispatchDevice, dtype: FloatDType) -> FloatTensor<Self> {
58        creation_op!(Float, device, |device| B::float_empty(shape, device, dtype))
59    }
60
61    fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
62        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_add(lhs, rhs) => Float)
63    }
64
65    fn float_add_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
66        unary_float!(lhs, float, |lhs| B::float_add_scalar(lhs, rhs) => Float)
67    }
68
69    fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
70        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_sub(lhs, rhs) => Float)
71    }
72
73    fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
74        unary_float!(lhs, float, |lhs| B::float_sub_scalar(lhs, rhs) => Float)
75    }
76
77    fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
78        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_mul(lhs, rhs) => Float)
79    }
80
81    fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
82        unary_float!(lhs, float, |lhs| B::float_mul_scalar(lhs, rhs) => Float)
83    }
84
85    fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
86        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_div(lhs, rhs) => Float)
87    }
88
89    fn float_div_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
90        unary_float!(lhs, float, |lhs| B::float_div_scalar(lhs, rhs) => Float)
91    }
92
93    fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
94        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_remainder(lhs, rhs) => Float)
95    }
96
97    fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
98        unary_float!(lhs, float, |lhs| B::float_remainder_scalar(lhs, rhs) => Float)
99    }
100
101    fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
102        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_matmul(lhs, rhs) => Float)
103    }
104
105    fn float_cross(
106        lhs: FloatTensor<Self>,
107        rhs: FloatTensor<Self>,
108        dim: usize,
109    ) -> FloatTensor<Self> {
110        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_cross(lhs, rhs, dim) => Float)
111    }
112
113    fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
114        unary_float!(tensor, float, |tensor| B::float_recip(tensor) => Float)
115    }
116
117    fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
118        unary_float!(tensor, float, |tensor| B::float_swap_dims(tensor, dim1, dim2) => Float)
119    }
120
121    fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
122        unary_float!(tensor, float, |tensor| B::float_permute(tensor, axes) => Float)
123    }
124
125    fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
126        unary_float!(tensor, float, |tensor| B::float_flip(tensor, axes) => Float)
127    }
128
129    fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
130        unary_float!(tensor, float, |tensor| B::float_reshape(tensor, shape) => Float)
131    }
132
133    fn float_gather(
134        dim: usize,
135        tensor: FloatTensor<Self>,
136        indices: IntTensor<Self>,
137    ) -> FloatTensor<Self> {
138        binary_float!((tensor, float), (indices, int), |tensor, indices| B::float_gather(dim, tensor, indices) => Float)
139    }
140
141    fn float_scatter_add(
142        dim: usize,
143        tensor: FloatTensor<Self>,
144        indices: IntTensor<Self>,
145        value: FloatTensor<Self>,
146    ) -> FloatTensor<Self> {
147        multi_op!(
148            inputs[(tensor, float), (indices, int), (value, float)], => Float,
149            B::float_scatter_add(dim, tensor, indices, value)
150        )
151    }
152
153    fn float_select(
154        tensor: FloatTensor<Self>,
155        dim: usize,
156        indices: IntTensor<Self>,
157    ) -> FloatTensor<Self> {
158        binary_float!((tensor, float), (indices, int), |tensor, indices| B::float_select(tensor, dim, indices) => Float)
159    }
160
161    fn float_select_add(
162        tensor: FloatTensor<Self>,
163        dim: usize,
164        indices: IntTensor<Self>,
165        value: FloatTensor<Self>,
166    ) -> FloatTensor<Self> {
167        multi_op!(
168            inputs[(tensor, float), (indices, int), (value, float)], => Float,
169            B::float_select_add(tensor, dim, indices, value)
170        )
171    }
172
173    fn float_slice(tensor: FloatTensor<Self>, slices: &[Slice]) -> FloatTensor<Self> {
174        unary_float!(tensor, float, |tensor| B::float_slice(tensor, slices) => Float)
175    }
176
177    fn float_slice_assign(
178        tensor: FloatTensor<Self>,
179        slices: &[Slice],
180        value: FloatTensor<Self>,
181    ) -> FloatTensor<Self> {
182        binary_float!((tensor, float), (value, float), |tensor, value| B::float_slice_assign(tensor, slices, value) => Float)
183    }
184
185    fn float_mask_where(
186        tensor: FloatTensor<Self>,
187        mask: BoolTensor<Self>,
188        value: FloatTensor<Self>,
189    ) -> FloatTensor<Self> {
190        multi_op!(
191            inputs[(tensor, float), (mask, bool), (value, float)], => Float,
192            B::float_mask_where(tensor, mask, value)
193        )
194    }
195
196    fn float_mask_fill(
197        tensor: FloatTensor<Self>,
198        mask: BoolTensor<Self>,
199        value: Scalar,
200    ) -> FloatTensor<Self> {
201        binary_float!((tensor, float), (mask, bool), |tensor, mask| B::float_mask_fill(tensor, mask, value) => Float)
202    }
203
204    fn float_equal(
205        lhs: FloatTensor<Self>,
206        rhs: FloatTensor<Self>,
207        out_dtype: BoolDType,
208    ) -> BoolTensor<Self> {
209        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_equal(lhs, rhs, out_dtype) => Bool)
210    }
211
212    fn float_equal_elem(
213        lhs: FloatTensor<Self>,
214        rhs: Scalar,
215        out_dtype: BoolDType,
216    ) -> BoolTensor<Self> {
217        unary_float!(lhs, float, |lhs| B::float_equal_elem(lhs, rhs, out_dtype) => Bool)
218    }
219
220    fn float_greater(
221        lhs: FloatTensor<Self>,
222        rhs: FloatTensor<Self>,
223        out_dtype: BoolDType,
224    ) -> BoolTensor<Self> {
225        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_greater(lhs, rhs, out_dtype) => Bool)
226    }
227
228    fn float_greater_elem(
229        lhs: FloatTensor<Self>,
230        rhs: Scalar,
231        out_dtype: BoolDType,
232    ) -> BoolTensor<Self> {
233        unary_float!(lhs, float, |lhs| B::float_greater_elem(lhs, rhs, out_dtype) => Bool)
234    }
235
236    fn float_greater_equal(
237        lhs: FloatTensor<Self>,
238        rhs: FloatTensor<Self>,
239        out_dtype: BoolDType,
240    ) -> BoolTensor<Self> {
241        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_greater_equal(lhs, rhs, out_dtype) => Bool)
242    }
243
244    fn float_greater_equal_elem(
245        lhs: FloatTensor<Self>,
246        rhs: Scalar,
247        out_dtype: BoolDType,
248    ) -> BoolTensor<Self> {
249        unary_float!(lhs, float, |lhs| B::float_greater_equal_elem(lhs, rhs, out_dtype) => Bool)
250    }
251
252    fn float_lower(
253        lhs: FloatTensor<Self>,
254        rhs: FloatTensor<Self>,
255        out_dtype: BoolDType,
256    ) -> BoolTensor<Self> {
257        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_lower(lhs, rhs, out_dtype) => Bool)
258    }
259
260    fn float_lower_elem(
261        lhs: FloatTensor<Self>,
262        rhs: Scalar,
263        out_dtype: BoolDType,
264    ) -> BoolTensor<Self> {
265        unary_float!(lhs, float, |lhs| B::float_lower_elem(lhs, rhs, out_dtype) => Bool)
266    }
267
268    fn float_lower_equal(
269        lhs: FloatTensor<Self>,
270        rhs: FloatTensor<Self>,
271        out_dtype: BoolDType,
272    ) -> BoolTensor<Self> {
273        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_lower_equal(lhs, rhs, out_dtype) => Bool)
274    }
275
276    fn float_lower_equal_elem(
277        lhs: FloatTensor<Self>,
278        rhs: Scalar,
279        out_dtype: BoolDType,
280    ) -> BoolTensor<Self> {
281        unary_float!(lhs, float, |lhs| B::float_lower_equal_elem(lhs, rhs, out_dtype) => Bool)
282    }
283
284    fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
285        unary_float!(tensor, float, |tensor| B::float_sum(tensor) => Float)
286    }
287
288    fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
289        unary_float!(tensor, float, |tensor| B::float_sum_dim(tensor, dim) => Float)
290    }
291
292    fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
293        unary_float!(tensor, float, |tensor| B::float_mean_dim(tensor, dim) => Float)
294    }
295
296    fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
297        unary_float!(tensor, float, |tensor| B::float_cumsum(tensor, dim) => Float)
298    }
299
300    fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
301        unary_float!(tensor, float, |tensor| B::float_cumprod(tensor, dim) => Float)
302    }
303
304    fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
305        unary_float!(tensor, float, |tensor| B::float_cummin(tensor, dim) => Float)
306    }
307
308    fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
309        unary_float!(tensor, float, |tensor| B::float_cummax(tensor, dim) => Float)
310    }
311
312    fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
313        unary_float!(tensor, float, |tensor| B::float_cast(tensor, dtype) => Float)
314    }
315
316    fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
317        unary_float!(tensor, float, |tensor| B::float_exp(tensor) => Float)
318    }
319
320    fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
321        unary_float!(tensor, float, |tensor| B::float_log(tensor) => Float)
322    }
323
324    fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
325        unary_float!(tensor, float, |tensor| B::float_log1p(tensor) => Float)
326    }
327
328    fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
329        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_powf(lhs, rhs) => Float)
330    }
331
332    fn float_powf_scalar_impl(tensor: FloatTensor<Self>, value: Scalar) -> FloatTensor<Self> {
333        unary_float!(tensor, float, |tensor| B::float_powf_scalar_impl(tensor, value) => Float)
334    }
335
336    fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
337        unary_float!(tensor, float, |tensor| B::float_sqrt(tensor) => Float)
338    }
339
340    fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
341        unary_float!(tensor, float, |tensor| B::float_abs(tensor) => Float)
342    }
343
344    fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
345        unary_float!(tensor, float, |tensor| B::float_cos(tensor) => Float)
346    }
347
348    fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
349        unary_float!(tensor, float, |tensor| B::float_sin(tensor) => Float)
350    }
351
352    fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
353        unary_float!(tensor, float, |tensor| B::float_tan(tensor) => Float)
354    }
355
356    fn float_cosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
357        unary_float!(tensor, float, |tensor| B::float_cosh(tensor) => Float)
358    }
359
360    fn float_sinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
361        unary_float!(tensor, float, |tensor| B::float_sinh(tensor) => Float)
362    }
363
364    fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
365        unary_float!(tensor, float, |tensor| B::float_tanh(tensor) => Float)
366    }
367
368    fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
369        unary_float!(tensor, float, |tensor| B::float_acos(tensor) => Float)
370    }
371
372    fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
373        unary_float!(tensor, float, |tensor| B::float_acosh(tensor) => Float)
374    }
375
376    fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
377        unary_float!(tensor, float, |tensor| B::float_asin(tensor) => Float)
378    }
379
380    fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
381        unary_float!(tensor, float, |tensor| B::float_asinh(tensor) => Float)
382    }
383
384    fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
385        unary_float!(tensor, float, |tensor| B::float_atan(tensor) => Float)
386    }
387
388    fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
389        unary_float!(tensor, float, |tensor| B::float_atanh(tensor) => Float)
390    }
391
392    fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
393        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_atan2(lhs, rhs) => Float)
394    }
395
396    fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
397        unary_float!(tensor, float, |tensor| B::float_round(tensor) => Float)
398    }
399
400    fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
401        unary_float!(tensor, float, |tensor| B::float_floor(tensor) => Float)
402    }
403
404    fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
405        unary_float!(tensor, float, |tensor| B::float_ceil(tensor) => Float)
406    }
407
408    fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
409        unary_float!(tensor, float, |tensor| B::float_trunc(tensor) => Float)
410    }
411
412    fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
413        unary_float!(tensor, float, |tensor| B::float_erf(tensor) => Float)
414    }
415
416    fn float_argmax(tensor: FloatTensor<Self>, dim: usize, out_dtype: IntDType) -> IntTensor<Self> {
417        unary_float!(tensor, float, |tensor| B::float_argmax(tensor, dim, out_dtype) => Int)
418    }
419
420    fn float_argmin(tensor: FloatTensor<Self>, dim: usize, out_dtype: IntDType) -> IntTensor<Self> {
421        unary_float!(tensor, float, |tensor| B::float_argmin(tensor, dim, out_dtype) => Int)
422    }
423
424    fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
425        unary_float!(tensor, float, |tensor| B::float_expand(tensor, shape) => Float)
426    }
427
428    fn float_unfold(
429        tensor: FloatTensor<Self>,
430        dim: usize,
431        size: usize,
432        step: usize,
433    ) -> FloatTensor<Self> {
434        unary_float!(tensor, float, |tensor| {
435            B::float_unfold(tensor, dim, size, step)
436        } => Float)
437    }
438
439    fn float_detach(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
440        unary_float!(tensor, float, |tensor| B::float_detach(tensor) => Float)
441    }
442
443    fn float_set_require_grad(tensor: FloatTensor<Self>, require_grad: bool) -> FloatTensor<Self> {
444        unary_float!(tensor, float, |tensor| B::float_set_require_grad(tensor, require_grad) => Float)
445    }
446
447    fn float_is_require_grad(tensor: &FloatTensor<Self>) -> bool {
448        unary_float!(ref tensor, float, |tensor| B::float_is_require_grad(tensor))
449    }
450
451    // Default implementation
452    fn float_zeros(shape: Shape, device: &DispatchDevice, dtype: FloatDType) -> FloatTensor<Self> {
453        creation_op!(Float, device, |device| B::float_zeros(shape, device, dtype))
454    }
455
456    fn float_ones(shape: Shape, device: &DispatchDevice, dtype: FloatDType) -> FloatTensor<Self> {
457        creation_op!(Float, device, |device| B::float_ones(shape, device, dtype))
458    }
459
460    fn float_full(
461        shape: Shape,
462        fill_value: Scalar,
463        device: &DispatchDevice,
464        dtype: FloatDType,
465    ) -> FloatTensor<Self> {
466        creation_op!(Float, device, |device| B::float_full(
467            shape, fill_value, device, dtype
468        ))
469    }
470
471    fn float_repeat_dim(tensor: FloatTensor<Self>, dim: usize, times: usize) -> FloatTensor<Self> {
472        unary_float!(tensor, float, |tensor| B::float_repeat_dim(tensor, dim, times) => Float)
473    }
474
475    fn float_clamp_min(tensor: FloatTensor<Self>, min: Scalar) -> FloatTensor<Self> {
476        unary_float!(tensor, float, |tensor| B::float_clamp_min(tensor, min) => Float)
477    }
478
479    fn float_clamp_max(tensor: FloatTensor<Self>, max: Scalar) -> FloatTensor<Self> {
480        unary_float!(tensor, float, |tensor| B::float_clamp_max(tensor, max) => Float)
481    }
482
483    fn float_clamp(tensor: FloatTensor<Self>, min: Scalar, max: Scalar) -> FloatTensor<Self> {
484        unary_float!(tensor, float, |tensor| B::float_clamp(tensor, min, max) => Float)
485    }
486
487    fn float_neg(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
488        unary_float!(tensor, float, |tensor| B::float_neg(tensor) => Float)
489    }
490
491    fn float_transpose(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
492        unary_float!(tensor, float, |tensor| B::float_transpose(tensor) => Float)
493    }
494
495    fn float_not_equal(
496        lhs: FloatTensor<Self>,
497        rhs: FloatTensor<Self>,
498        out_dtype: BoolDType,
499    ) -> BoolTensor<Self> {
500        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_not_equal(lhs, rhs, out_dtype) => Bool)
501    }
502
503    fn float_not_equal_elem(
504        lhs: FloatTensor<Self>,
505        rhs: Scalar,
506        out_dtype: BoolDType,
507    ) -> BoolTensor<Self> {
508        unary_float!(lhs, float, |lhs| B::float_not_equal_elem(lhs, rhs, out_dtype) => Bool)
509    }
510
511    fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
512        unary_float!(tensor, float, |tensor| B::float_prod(tensor) => Float)
513    }
514
515    fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
516        unary_float!(tensor, float, |tensor| B::float_prod_dim(tensor, dim) => Float)
517    }
518
519    fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
520        unary_float!(tensor, float, |tensor| B::float_mean(tensor) => Float)
521    }
522
523    fn float_powi(lhs: FloatTensor<Self>, rhs: IntTensor<Self>) -> FloatTensor<Self> {
524        binary_float!((lhs, float), (rhs, int), |lhs, rhs| B::float_powi(lhs, rhs) => Float)
525    }
526
527    fn float_powi_scalar_impl(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
528        unary_float!(lhs, float, |lhs| B::float_powi_scalar_impl(lhs, rhs) => Float)
529    }
530
531    fn float_powf_scalar(tensor: FloatTensor<Self>, value: Scalar) -> FloatTensor<Self> {
532        unary_float!(tensor, float, |tensor| B::float_powf_scalar(tensor, value) => Float)
533    }
534
535    fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {
536        vec_op!(tensors, float, |tensors| B::float_cat(tensors, dim) => Float)
537    }
538
539    fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
540        unary_float!(tensor, float, |tensor| B::float_max(tensor) => Float)
541    }
542
543    fn float_max_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
544        unary_float!(tensor, float, |tensor| B::float_max_dim(tensor, dim) => Float)
545    }
546
547    fn float_max_dim_with_indices(
548        tensor: FloatTensor<Self>,
549        dim: usize,
550        indices_dtype: IntDType,
551    ) -> (FloatTensor<Self>, IntTensor<Self>) {
552        multi_op!(
553            inputs[(tensor, float)],
554            outputs[(out, Float), (indices, Int)],
555            B::float_max_dim_with_indices(tensor, dim, indices_dtype)
556        )
557    }
558
559    fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
560        unary_float!(tensor, float, |tensor| B::float_min(tensor) => Float)
561    }
562
563    fn float_min_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
564        unary_float!(tensor, float, |tensor| B::float_min_dim(tensor, dim) => Float)
565    }
566
567    fn float_min_dim_with_indices(
568        tensor: FloatTensor<Self>,
569        dim: usize,
570        indices_dtype: IntDType,
571    ) -> (FloatTensor<Self>, IntTensor<Self>) {
572        multi_op!(
573            inputs[(tensor, float)],
574            outputs[(out, Float), (indices, Int)],
575            B::float_min_dim_with_indices(tensor, dim, indices_dtype)
576        )
577    }
578
579    fn float_max_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
580        unary_float!(tensor, float, |tensor| B::float_max_abs(tensor) => Float)
581    }
582
583    fn float_max_abs_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
584        unary_float!(tensor, float, |tensor| B::float_max_abs_dim(tensor, dim) => Float)
585    }
586
587    fn float_any(tensor: FloatTensor<Self>, out_dtype: BoolDType) -> BoolTensor<Self> {
588        unary_float!(tensor, float, |tensor| B::float_any(tensor, out_dtype) => Bool)
589    }
590
591    fn float_any_dim(
592        tensor: FloatTensor<Self>,
593        dim: usize,
594        out_dtype: BoolDType,
595    ) -> BoolTensor<Self> {
596        unary_float!(tensor, float, |tensor| B::float_any_dim(tensor, dim, out_dtype) => Bool)
597    }
598
599    fn float_all(tensor: FloatTensor<Self>, out_dtype: BoolDType) -> BoolTensor<Self> {
600        unary_float!(tensor, float, |tensor| B::float_all(tensor, out_dtype) => Bool)
601    }
602
603    fn float_all_dim(
604        tensor: FloatTensor<Self>,
605        dim: usize,
606        out_dtype: BoolDType,
607    ) -> BoolTensor<Self> {
608        unary_float!(tensor, float, |tensor| B::float_all_dim(tensor, dim, out_dtype) => Bool)
609    }
610
611    fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
612        unary_float!(tensor, float, |tensor| B::float_sign(tensor) => Float)
613    }
614
615    fn float_sort(tensor: FloatTensor<Self>, dim: usize, descending: bool) -> FloatTensor<Self> {
616        unary_float!(tensor, float, |tensor| B::float_sort(tensor, dim, descending) => Float)
617    }
618
619    fn float_sort_with_indices(
620        tensor: FloatTensor<Self>,
621        dim: usize,
622        descending: bool,
623        indices_dtype: IntDType,
624    ) -> (FloatTensor<Self>, IntTensor<Self>) {
625        multi_op!(
626            inputs[(tensor, float)],
627            outputs[(out, Float), (indices, Int)],
628            B::float_sort_with_indices(tensor, dim, descending, indices_dtype)
629        )
630    }
631
632    fn float_argsort(
633        tensor: FloatTensor<Self>,
634        dim: usize,
635        descending: bool,
636        out_dtype: IntDType,
637    ) -> IntTensor<Self> {
638        unary_float!(tensor, float, |tensor| B::float_argsort(tensor, dim, descending, out_dtype) => Int)
639    }
640
641    fn float_grid_sample_2d(
642        tensor: FloatTensor<Self>,
643        grid: FloatTensor<Self>,
644        options: burn_backend::ops::GridSampleOptions,
645    ) -> FloatTensor<Self> {
646        binary_float!((tensor, float), (grid, float), |tensor, grid| B::float_grid_sample_2d(tensor, grid, options) => Float)
647    }
648
649    fn float_is_nan(tensor: FloatTensor<Self>, out_dtype: BoolDType) -> BoolTensor<Self> {
650        unary_float!(tensor, float, |tensor| B::float_is_nan(tensor, out_dtype) => Bool)
651    }
652
653    fn float_is_inf(tensor: FloatTensor<Self>, out_dtype: BoolDType) -> BoolTensor<Self> {
654        unary_float!(tensor, float, |tensor| B::float_is_inf(tensor, out_dtype) => Bool)
655    }
656}