// burn_dispatch/ops/tensor.rs

use alloc::vec::Vec;
use burn_backend::{
    BoolDType, ExecutionError, FloatDType, IntDType, Scalar, Shape, Slice, TensorData,
    ops::FloatTensorOps,
    tensor::{BoolTensor, FloatTensor, IntTensor},
};

use crate::backends::*;
use crate::{Dispatch, DispatchDevice};

11impl FloatTensorOps<Self> for Dispatch {
12    fn float_from_data(
13        data: burn_backend::TensorData,
14        device: &DispatchDevice,
15    ) -> FloatTensor<Self> {
16        creation_op!(Float, device, |device| B::float_from_data(data, device))
17    }
18
19    fn float_random(
20        shape: Shape,
21        distribution: burn_backend::Distribution,
22        device: &DispatchDevice,
23        dtype: FloatDType,
24    ) -> FloatTensor<Self> {
25        creation_op!(Float, device, |device| {
26            B::float_random(shape, distribution, device, dtype)
27        })
28    }
29
30    async fn float_into_data(tensor: FloatTensor<Self>) -> Result<TensorData, ExecutionError> {
31        unary_float!(tensor, float, |tensor| B::float_into_data(tensor).await)
32    }
33
34    fn float_device(tensor: &FloatTensor<Self>) -> DispatchDevice {
35        tensor.device()
36    }
37
38    fn float_to_device(tensor: FloatTensor<Self>, device: &DispatchDevice) -> FloatTensor<Self> {
39        float_to_device!(
40            Float,
41            float,
42            tensor,
43            device,
44            float_to_device,
45            |inner, device| {
46                let data =
47                    burn_backend::read_sync(B1::float_into_data(inner)).expect("Should read data");
48                B2::float_from_data(data, device)
49            }
50        )
51    }
52
53    fn float_into_int(tensor: FloatTensor<Self>, dtype: burn_backend::IntDType) -> IntTensor<Self> {
54        unary_float!(tensor, float, |tensor| B::float_into_int(tensor, dtype) => Int)
55    }
56
57    fn float_empty(shape: Shape, device: &DispatchDevice, dtype: FloatDType) -> FloatTensor<Self> {
58        creation_op!(Float, device, |device| B::float_empty(shape, device, dtype))
59    }
60
61    fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
62        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_add(lhs, rhs) => Float)
63    }
64
65    fn float_add_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
66        unary_float!(lhs, float, |lhs| B::float_add_scalar(lhs, rhs) => Float)
67    }
68
69    fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
70        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_sub(lhs, rhs) => Float)
71    }
72
73    fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
74        unary_float!(lhs, float, |lhs| B::float_sub_scalar(lhs, rhs) => Float)
75    }
76
77    fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
78        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_mul(lhs, rhs) => Float)
79    }
80
81    fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
82        unary_float!(lhs, float, |lhs| B::float_mul_scalar(lhs, rhs) => Float)
83    }
84
85    fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
86        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_div(lhs, rhs) => Float)
87    }
88
89    fn float_div_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
90        unary_float!(lhs, float, |lhs| B::float_div_scalar(lhs, rhs) => Float)
91    }
92
93    fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
94        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_remainder(lhs, rhs) => Float)
95    }
96
97    fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
98        unary_float!(lhs, float, |lhs| B::float_remainder_scalar(lhs, rhs) => Float)
99    }
100
101    fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
102        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_matmul(lhs, rhs) => Float)
103    }
104
105    fn float_cross(
106        lhs: FloatTensor<Self>,
107        rhs: FloatTensor<Self>,
108        dim: usize,
109    ) -> FloatTensor<Self> {
110        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_cross(lhs, rhs, dim) => Float)
111    }
112
113    fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
114        unary_float!(tensor, float, |tensor| B::float_recip(tensor) => Float)
115    }
116
117    fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
118        unary_float!(tensor, float, |tensor| B::float_swap_dims(tensor, dim1, dim2) => Float)
119    }
120
121    fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
122        unary_float!(tensor, float, |tensor| B::float_permute(tensor, axes) => Float)
123    }
124
125    fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
126        unary_float!(tensor, float, |tensor| B::float_flip(tensor, axes) => Float)
127    }
128
129    fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
130        unary_float!(tensor, float, |tensor| B::float_reshape(tensor, shape) => Float)
131    }
132
133    fn float_gather(
134        dim: usize,
135        tensor: FloatTensor<Self>,
136        indices: IntTensor<Self>,
137    ) -> FloatTensor<Self> {
138        binary_float!((tensor, float), (indices, int), |tensor, indices| B::float_gather(dim, tensor, indices) => Float)
139    }
140
141    fn float_scatter_add(
142        dim: usize,
143        tensor: FloatTensor<Self>,
144        indices: IntTensor<Self>,
145        value: FloatTensor<Self>,
146    ) -> FloatTensor<Self> {
147        multi_op!(
148            inputs[(tensor, float), (indices, int), (value, float)], => Float,
149            B::float_scatter_add(dim, tensor, indices, value)
150        )
151    }
152
153    fn float_scatter_nd(
154        data: FloatTensor<Self>,
155        indices: IntTensor<Self>,
156        values: FloatTensor<Self>,
157        reduction: burn_backend::tensor::IndexingUpdateOp,
158    ) -> FloatTensor<Self> {
159        multi_op!(
160            inputs[(data, float), (indices, int), (values, float)], => Float,
161            B::float_scatter_nd(data, indices, values, reduction)
162        )
163    }
164
165    fn float_gather_nd(data: FloatTensor<Self>, indices: IntTensor<Self>) -> FloatTensor<Self> {
166        binary_float!((data, float), (indices, int), |data, indices| B::float_gather_nd(data, indices) => Float)
167    }
168
169    fn float_select(
170        tensor: FloatTensor<Self>,
171        dim: usize,
172        indices: IntTensor<Self>,
173    ) -> FloatTensor<Self> {
174        binary_float!((tensor, float), (indices, int), |tensor, indices| B::float_select(tensor, dim, indices) => Float)
175    }
176
177    fn float_select_add(
178        tensor: FloatTensor<Self>,
179        dim: usize,
180        indices: IntTensor<Self>,
181        value: FloatTensor<Self>,
182    ) -> FloatTensor<Self> {
183        multi_op!(
184            inputs[(tensor, float), (indices, int), (value, float)], => Float,
185            B::float_select_add(tensor, dim, indices, value)
186        )
187    }
188
189    fn float_slice(tensor: FloatTensor<Self>, slices: &[Slice]) -> FloatTensor<Self> {
190        unary_float!(tensor, float, |tensor| B::float_slice(tensor, slices) => Float)
191    }
192
193    fn float_slice_assign(
194        tensor: FloatTensor<Self>,
195        slices: &[Slice],
196        value: FloatTensor<Self>,
197    ) -> FloatTensor<Self> {
198        binary_float!((tensor, float), (value, float), |tensor, value| B::float_slice_assign(tensor, slices, value) => Float)
199    }
200
201    fn float_mask_where(
202        tensor: FloatTensor<Self>,
203        mask: BoolTensor<Self>,
204        value: FloatTensor<Self>,
205    ) -> FloatTensor<Self> {
206        multi_op!(
207            inputs[(tensor, float), (mask, bool), (value, float)], => Float,
208            B::float_mask_where(tensor, mask, value)
209        )
210    }
211
212    fn float_mask_fill(
213        tensor: FloatTensor<Self>,
214        mask: BoolTensor<Self>,
215        value: Scalar,
216    ) -> FloatTensor<Self> {
217        binary_float!((tensor, float), (mask, bool), |tensor, mask| B::float_mask_fill(tensor, mask, value) => Float)
218    }
219
220    fn float_equal(
221        lhs: FloatTensor<Self>,
222        rhs: FloatTensor<Self>,
223        out_dtype: BoolDType,
224    ) -> BoolTensor<Self> {
225        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_equal(lhs, rhs, out_dtype) => Bool)
226    }
227
228    fn float_equal_elem(
229        lhs: FloatTensor<Self>,
230        rhs: Scalar,
231        out_dtype: BoolDType,
232    ) -> BoolTensor<Self> {
233        unary_float!(lhs, float, |lhs| B::float_equal_elem(lhs, rhs, out_dtype) => Bool)
234    }
235
236    fn float_greater(
237        lhs: FloatTensor<Self>,
238        rhs: FloatTensor<Self>,
239        out_dtype: BoolDType,
240    ) -> BoolTensor<Self> {
241        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_greater(lhs, rhs, out_dtype) => Bool)
242    }
243
244    fn float_greater_elem(
245        lhs: FloatTensor<Self>,
246        rhs: Scalar,
247        out_dtype: BoolDType,
248    ) -> BoolTensor<Self> {
249        unary_float!(lhs, float, |lhs| B::float_greater_elem(lhs, rhs, out_dtype) => Bool)
250    }
251
252    fn float_greater_equal(
253        lhs: FloatTensor<Self>,
254        rhs: FloatTensor<Self>,
255        out_dtype: BoolDType,
256    ) -> BoolTensor<Self> {
257        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_greater_equal(lhs, rhs, out_dtype) => Bool)
258    }
259
260    fn float_greater_equal_elem(
261        lhs: FloatTensor<Self>,
262        rhs: Scalar,
263        out_dtype: BoolDType,
264    ) -> BoolTensor<Self> {
265        unary_float!(lhs, float, |lhs| B::float_greater_equal_elem(lhs, rhs, out_dtype) => Bool)
266    }
267
268    fn float_lower(
269        lhs: FloatTensor<Self>,
270        rhs: FloatTensor<Self>,
271        out_dtype: BoolDType,
272    ) -> BoolTensor<Self> {
273        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_lower(lhs, rhs, out_dtype) => Bool)
274    }
275
276    fn float_lower_elem(
277        lhs: FloatTensor<Self>,
278        rhs: Scalar,
279        out_dtype: BoolDType,
280    ) -> BoolTensor<Self> {
281        unary_float!(lhs, float, |lhs| B::float_lower_elem(lhs, rhs, out_dtype) => Bool)
282    }
283
284    fn float_lower_equal(
285        lhs: FloatTensor<Self>,
286        rhs: FloatTensor<Self>,
287        out_dtype: BoolDType,
288    ) -> BoolTensor<Self> {
289        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_lower_equal(lhs, rhs, out_dtype) => Bool)
290    }
291
292    fn float_lower_equal_elem(
293        lhs: FloatTensor<Self>,
294        rhs: Scalar,
295        out_dtype: BoolDType,
296    ) -> BoolTensor<Self> {
297        unary_float!(lhs, float, |lhs| B::float_lower_equal_elem(lhs, rhs, out_dtype) => Bool)
298    }
299
300    fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
301        unary_float!(tensor, float, |tensor| B::float_sum(tensor) => Float)
302    }
303
304    fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
305        unary_float!(tensor, float, |tensor| B::float_sum_dim(tensor, dim) => Float)
306    }
307
308    fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
309        unary_float!(tensor, float, |tensor| B::float_mean_dim(tensor, dim) => Float)
310    }
311
312    fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
313        unary_float!(tensor, float, |tensor| B::float_cumsum(tensor, dim) => Float)
314    }
315
316    fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
317        unary_float!(tensor, float, |tensor| B::float_cumprod(tensor, dim) => Float)
318    }
319
320    fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
321        unary_float!(tensor, float, |tensor| B::float_cummin(tensor, dim) => Float)
322    }
323
324    fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
325        unary_float!(tensor, float, |tensor| B::float_cummax(tensor, dim) => Float)
326    }
327
328    fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
329        unary_float!(tensor, float, |tensor| B::float_cast(tensor, dtype) => Float)
330    }
331
332    fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
333        unary_float!(tensor, float, |tensor| B::float_exp(tensor) => Float)
334    }
335
336    fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
337        unary_float!(tensor, float, |tensor| B::float_log(tensor) => Float)
338    }
339
340    fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
341        unary_float!(tensor, float, |tensor| B::float_log1p(tensor) => Float)
342    }
343
344    fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
345        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_powf(lhs, rhs) => Float)
346    }
347
348    fn float_powf_scalar_impl(tensor: FloatTensor<Self>, value: Scalar) -> FloatTensor<Self> {
349        unary_float!(tensor, float, |tensor| B::float_powf_scalar_impl(tensor, value) => Float)
350    }
351
352    fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
353        unary_float!(tensor, float, |tensor| B::float_sqrt(tensor) => Float)
354    }
355
356    fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
357        unary_float!(tensor, float, |tensor| B::float_abs(tensor) => Float)
358    }
359
360    fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
361        unary_float!(tensor, float, |tensor| B::float_cos(tensor) => Float)
362    }
363
364    fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
365        unary_float!(tensor, float, |tensor| B::float_sin(tensor) => Float)
366    }
367
368    fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
369        unary_float!(tensor, float, |tensor| B::float_tan(tensor) => Float)
370    }
371
372    fn float_cosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
373        unary_float!(tensor, float, |tensor| B::float_cosh(tensor) => Float)
374    }
375
376    fn float_sinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
377        unary_float!(tensor, float, |tensor| B::float_sinh(tensor) => Float)
378    }
379
380    fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
381        unary_float!(tensor, float, |tensor| B::float_tanh(tensor) => Float)
382    }
383
384    fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
385        unary_float!(tensor, float, |tensor| B::float_acos(tensor) => Float)
386    }
387
388    fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
389        unary_float!(tensor, float, |tensor| B::float_acosh(tensor) => Float)
390    }
391
392    fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
393        unary_float!(tensor, float, |tensor| B::float_asin(tensor) => Float)
394    }
395
396    fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
397        unary_float!(tensor, float, |tensor| B::float_asinh(tensor) => Float)
398    }
399
400    fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
401        unary_float!(tensor, float, |tensor| B::float_atan(tensor) => Float)
402    }
403
404    fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
405        unary_float!(tensor, float, |tensor| B::float_atanh(tensor) => Float)
406    }
407
408    fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
409        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_atan2(lhs, rhs) => Float)
410    }
411
412    fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
413        unary_float!(tensor, float, |tensor| B::float_round(tensor) => Float)
414    }
415
416    fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
417        unary_float!(tensor, float, |tensor| B::float_floor(tensor) => Float)
418    }
419
420    fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
421        unary_float!(tensor, float, |tensor| B::float_ceil(tensor) => Float)
422    }
423
424    fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
425        unary_float!(tensor, float, |tensor| B::float_trunc(tensor) => Float)
426    }
427
428    fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
429        unary_float!(tensor, float, |tensor| B::float_erf(tensor) => Float)
430    }
431
432    fn float_argmax(tensor: FloatTensor<Self>, dim: usize, out_dtype: IntDType) -> IntTensor<Self> {
433        unary_float!(tensor, float, |tensor| B::float_argmax(tensor, dim, out_dtype) => Int)
434    }
435
436    fn float_argtopk(
437        tensor: FloatTensor<Self>,
438        dim: usize,
439        k: usize,
440        out_dtype: IntDType,
441    ) -> IntTensor<Self> {
442        unary_float!(tensor, float, |tensor| B::float_argtopk(tensor, dim, k, out_dtype) => Int)
443    }
444
445    fn float_topk(tensor: FloatTensor<Self>, dim: usize, k: usize) -> FloatTensor<Self> {
446        unary_float!(tensor, float, |tensor| B::float_topk(tensor, dim, k) => Float)
447    }
448
449    fn float_argmin(tensor: FloatTensor<Self>, dim: usize, out_dtype: IntDType) -> IntTensor<Self> {
450        unary_float!(tensor, float, |tensor| B::float_argmin(tensor, dim, out_dtype) => Int)
451    }
452
453    fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
454        unary_float!(tensor, float, |tensor| B::float_expand(tensor, shape) => Float)
455    }
456
457    fn float_unfold(
458        tensor: FloatTensor<Self>,
459        dim: usize,
460        size: usize,
461        step: usize,
462    ) -> FloatTensor<Self> {
463        unary_float!(tensor, float, |tensor| {
464            B::float_unfold(tensor, dim, size, step)
465        } => Float)
466    }
467
468    fn float_detach(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
469        unary_float!(tensor, float, |tensor| B::float_detach(tensor) => Float)
470    }
471
472    fn float_set_require_grad(tensor: FloatTensor<Self>, require_grad: bool) -> FloatTensor<Self> {
473        unary_float!(tensor, float, |tensor| B::float_set_require_grad(tensor, require_grad) => Float)
474    }
475
476    fn float_is_require_grad(tensor: &FloatTensor<Self>) -> bool {
477        unary_float!(ref tensor, float, |tensor| B::float_is_require_grad(tensor))
478    }
479
480    // Default implementation
481    fn float_zeros(shape: Shape, device: &DispatchDevice, dtype: FloatDType) -> FloatTensor<Self> {
482        creation_op!(Float, device, |device| B::float_zeros(shape, device, dtype))
483    }
484
485    fn float_ones(shape: Shape, device: &DispatchDevice, dtype: FloatDType) -> FloatTensor<Self> {
486        creation_op!(Float, device, |device| B::float_ones(shape, device, dtype))
487    }
488
489    fn float_full(
490        shape: Shape,
491        fill_value: Scalar,
492        device: &DispatchDevice,
493        dtype: FloatDType,
494    ) -> FloatTensor<Self> {
495        creation_op!(Float, device, |device| B::float_full(
496            shape, fill_value, device, dtype
497        ))
498    }
499
500    fn float_repeat_dim(tensor: FloatTensor<Self>, dim: usize, times: usize) -> FloatTensor<Self> {
501        unary_float!(tensor, float, |tensor| B::float_repeat_dim(tensor, dim, times) => Float)
502    }
503
504    fn float_clamp_min(tensor: FloatTensor<Self>, min: Scalar) -> FloatTensor<Self> {
505        unary_float!(tensor, float, |tensor| B::float_clamp_min(tensor, min) => Float)
506    }
507
508    fn float_clamp_max(tensor: FloatTensor<Self>, max: Scalar) -> FloatTensor<Self> {
509        unary_float!(tensor, float, |tensor| B::float_clamp_max(tensor, max) => Float)
510    }
511
512    fn float_clamp(tensor: FloatTensor<Self>, min: Scalar, max: Scalar) -> FloatTensor<Self> {
513        unary_float!(tensor, float, |tensor| B::float_clamp(tensor, min, max) => Float)
514    }
515
516    fn float_neg(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
517        unary_float!(tensor, float, |tensor| B::float_neg(tensor) => Float)
518    }
519
520    fn float_transpose(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
521        unary_float!(tensor, float, |tensor| B::float_transpose(tensor) => Float)
522    }
523
524    fn float_not_equal(
525        lhs: FloatTensor<Self>,
526        rhs: FloatTensor<Self>,
527        out_dtype: BoolDType,
528    ) -> BoolTensor<Self> {
529        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_not_equal(lhs, rhs, out_dtype) => Bool)
530    }
531
532    fn float_not_equal_elem(
533        lhs: FloatTensor<Self>,
534        rhs: Scalar,
535        out_dtype: BoolDType,
536    ) -> BoolTensor<Self> {
537        unary_float!(lhs, float, |lhs| B::float_not_equal_elem(lhs, rhs, out_dtype) => Bool)
538    }
539
540    fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
541        unary_float!(tensor, float, |tensor| B::float_prod(tensor) => Float)
542    }
543
544    fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
545        unary_float!(tensor, float, |tensor| B::float_prod_dim(tensor, dim) => Float)
546    }
547
548    fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
549        unary_float!(tensor, float, |tensor| B::float_mean(tensor) => Float)
550    }
551
552    fn float_powi(lhs: FloatTensor<Self>, rhs: IntTensor<Self>) -> FloatTensor<Self> {
553        binary_float!((lhs, float), (rhs, int), |lhs, rhs| B::float_powi(lhs, rhs) => Float)
554    }
555
556    fn float_powi_scalar_impl(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
557        unary_float!(lhs, float, |lhs| B::float_powi_scalar_impl(lhs, rhs) => Float)
558    }
559
560    fn float_powf_scalar(tensor: FloatTensor<Self>, value: Scalar) -> FloatTensor<Self> {
561        unary_float!(tensor, float, |tensor| B::float_powf_scalar(tensor, value) => Float)
562    }
563
564    fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {
565        vec_op!(tensors, float, |tensors| B::float_cat(tensors, dim) => Float)
566    }
567
568    fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
569        unary_float!(tensor, float, |tensor| B::float_max(tensor) => Float)
570    }
571
572    fn float_max_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
573        unary_float!(tensor, float, |tensor| B::float_max_dim(tensor, dim) => Float)
574    }
575
576    fn float_max_dim_with_indices(
577        tensor: FloatTensor<Self>,
578        dim: usize,
579        indices_dtype: IntDType,
580    ) -> (FloatTensor<Self>, IntTensor<Self>) {
581        multi_op!(
582            inputs[(tensor, float)],
583            outputs[(out, Float), (indices, Int)],
584            B::float_max_dim_with_indices(tensor, dim, indices_dtype)
585        )
586    }
587
588    fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
589        unary_float!(tensor, float, |tensor| B::float_min(tensor) => Float)
590    }
591
592    fn float_min_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
593        unary_float!(tensor, float, |tensor| B::float_min_dim(tensor, dim) => Float)
594    }
595
596    fn float_min_dim_with_indices(
597        tensor: FloatTensor<Self>,
598        dim: usize,
599        indices_dtype: IntDType,
600    ) -> (FloatTensor<Self>, IntTensor<Self>) {
601        multi_op!(
602            inputs[(tensor, float)],
603            outputs[(out, Float), (indices, Int)],
604            B::float_min_dim_with_indices(tensor, dim, indices_dtype)
605        )
606    }
607
608    fn float_max_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
609        unary_float!(tensor, float, |tensor| B::float_max_abs(tensor) => Float)
610    }
611
612    fn float_max_abs_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
613        unary_float!(tensor, float, |tensor| B::float_max_abs_dim(tensor, dim) => Float)
614    }
615
616    fn float_any(tensor: FloatTensor<Self>, out_dtype: BoolDType) -> BoolTensor<Self> {
617        unary_float!(tensor, float, |tensor| B::float_any(tensor, out_dtype) => Bool)
618    }
619
620    fn float_any_dim(
621        tensor: FloatTensor<Self>,
622        dim: usize,
623        out_dtype: BoolDType,
624    ) -> BoolTensor<Self> {
625        unary_float!(tensor, float, |tensor| B::float_any_dim(tensor, dim, out_dtype) => Bool)
626    }
627
628    fn float_all(tensor: FloatTensor<Self>, out_dtype: BoolDType) -> BoolTensor<Self> {
629        unary_float!(tensor, float, |tensor| B::float_all(tensor, out_dtype) => Bool)
630    }
631
632    fn float_all_dim(
633        tensor: FloatTensor<Self>,
634        dim: usize,
635        out_dtype: BoolDType,
636    ) -> BoolTensor<Self> {
637        unary_float!(tensor, float, |tensor| B::float_all_dim(tensor, dim, out_dtype) => Bool)
638    }
639
640    fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
641        unary_float!(tensor, float, |tensor| B::float_sign(tensor) => Float)
642    }
643
644    fn float_sort(tensor: FloatTensor<Self>, dim: usize, descending: bool) -> FloatTensor<Self> {
645        unary_float!(tensor, float, |tensor| B::float_sort(tensor, dim, descending) => Float)
646    }
647
648    fn float_sort_with_indices(
649        tensor: FloatTensor<Self>,
650        dim: usize,
651        descending: bool,
652        indices_dtype: IntDType,
653    ) -> (FloatTensor<Self>, IntTensor<Self>) {
654        multi_op!(
655            inputs[(tensor, float)],
656            outputs[(out, Float), (indices, Int)],
657            B::float_sort_with_indices(tensor, dim, descending, indices_dtype)
658        )
659    }
660
661    fn float_argsort(
662        tensor: FloatTensor<Self>,
663        dim: usize,
664        descending: bool,
665        out_dtype: IntDType,
666    ) -> IntTensor<Self> {
667        unary_float!(tensor, float, |tensor| B::float_argsort(tensor, dim, descending, out_dtype) => Int)
668    }
669
670    fn float_grid_sample_2d(
671        tensor: FloatTensor<Self>,
672        grid: FloatTensor<Self>,
673        options: burn_backend::ops::GridSampleOptions,
674    ) -> FloatTensor<Self> {
675        binary_float!((tensor, float), (grid, float), |tensor, grid| B::float_grid_sample_2d(tensor, grid, options) => Float)
676    }
677
678    fn float_is_nan(tensor: FloatTensor<Self>, out_dtype: BoolDType) -> BoolTensor<Self> {
679        unary_float!(tensor, float, |tensor| B::float_is_nan(tensor, out_dtype) => Bool)
680    }
681
682    fn float_is_inf(tensor: FloatTensor<Self>, out_dtype: BoolDType) -> BoolTensor<Self> {
683        unary_float!(tensor, float, |tensor| B::float_is_inf(tensor, out_dtype) => Bool)
684    }
685}