Skip to main content

burn_dispatch/ops/
tensor.rs

1use burn_backend::{
2    ExecutionError, Scalar, TensorData,
3    ops::FloatTensorOps,
4    tensor::{BoolTensor, FloatTensor, IntTensor},
5};
6use burn_std::{FloatDType, Shape, Slice};
7
8use crate::backends::*;
9use crate::{Dispatch, DispatchDevice};
10
// TODO: remove the backend default element-type generics now that we have per-device defaults.
// https://github.com/tracel-ai/burn/issues/3642
13
14impl FloatTensorOps<Self> for Dispatch {
15    fn float_from_data(
16        data: burn_backend::TensorData,
17        device: &DispatchDevice,
18    ) -> FloatTensor<Self> {
19        creation_op!(Float, device, |device| B::float_from_data(data, device))
20    }
21
22    fn float_random(
23        shape: Shape,
24        distribution: burn_backend::Distribution,
25        device: &DispatchDevice,
26    ) -> FloatTensor<Self> {
27        creation_op!(Float, device, |device| {
28            B::float_random(shape, distribution, device)
29        })
30    }
31
32    async fn float_into_data(tensor: FloatTensor<Self>) -> Result<TensorData, ExecutionError> {
33        unary_float!(tensor, float, |tensor| B::float_into_data(tensor).await)
34    }
35
36    fn float_device(tensor: &FloatTensor<Self>) -> DispatchDevice {
37        tensor.device()
38    }
39
40    fn float_to_device(tensor: FloatTensor<Self>, device: &DispatchDevice) -> FloatTensor<Self> {
41        float_to_device!(
42            Float,
43            float,
44            tensor,
45            device,
46            float_to_device,
47            |inner, device| {
48                let data =
49                    burn_backend::read_sync(B1::float_into_data(inner)).expect("Should read data");
50                B2::float_from_data(data, device)
51            }
52        )
53    }
54
55    fn float_into_int(tensor: FloatTensor<Self>) -> IntTensor<Self> {
56        unary_float!(tensor, float, |tensor| B::float_into_int(tensor) => Int)
57    }
58
59    fn float_empty(shape: Shape, device: &DispatchDevice, dtype: FloatDType) -> FloatTensor<Self> {
60        creation_op!(Float, device, |device| B::float_empty(shape, device, dtype))
61    }
62
63    fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
64        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_add(lhs, rhs) => Float)
65    }
66
67    fn float_add_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
68        unary_float!(lhs, float, |lhs| B::float_add_scalar(lhs, rhs) => Float)
69    }
70
71    fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
72        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_sub(lhs, rhs) => Float)
73    }
74
75    fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
76        unary_float!(lhs, float, |lhs| B::float_sub_scalar(lhs, rhs) => Float)
77    }
78
79    fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
80        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_mul(lhs, rhs) => Float)
81    }
82
83    fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
84        unary_float!(lhs, float, |lhs| B::float_mul_scalar(lhs, rhs) => Float)
85    }
86
87    fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
88        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_div(lhs, rhs) => Float)
89    }
90
91    fn float_div_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
92        unary_float!(lhs, float, |lhs| B::float_div_scalar(lhs, rhs) => Float)
93    }
94
95    fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
96        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_remainder(lhs, rhs) => Float)
97    }
98
99    fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
100        unary_float!(lhs, float, |lhs| B::float_remainder_scalar(lhs, rhs) => Float)
101    }
102
103    fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
104        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_matmul(lhs, rhs) => Float)
105    }
106
107    fn float_cross(
108        lhs: FloatTensor<Self>,
109        rhs: FloatTensor<Self>,
110        dim: usize,
111    ) -> FloatTensor<Self> {
112        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_cross(lhs, rhs, dim) => Float)
113    }
114
115    fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
116        unary_float!(tensor, float, |tensor| B::float_recip(tensor) => Float)
117    }
118
119    fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
120        unary_float!(tensor, float, |tensor| B::float_swap_dims(tensor, dim1, dim2) => Float)
121    }
122
123    fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
124        unary_float!(tensor, float, |tensor| B::float_permute(tensor, axes) => Float)
125    }
126
127    fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
128        unary_float!(tensor, float, |tensor| B::float_flip(tensor, axes) => Float)
129    }
130
131    fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
132        unary_float!(tensor, float, |tensor| B::float_reshape(tensor, shape) => Float)
133    }
134
135    fn float_gather(
136        dim: usize,
137        tensor: FloatTensor<Self>,
138        indices: IntTensor<Self>,
139    ) -> FloatTensor<Self> {
140        binary_float!((tensor, float), (indices, int), |tensor, indices| B::float_gather(dim, tensor, indices) => Float)
141    }
142
143    fn float_scatter_add(
144        dim: usize,
145        tensor: FloatTensor<Self>,
146        indices: IntTensor<Self>,
147        value: FloatTensor<Self>,
148    ) -> FloatTensor<Self> {
149        multi_op!(
150            inputs[(tensor, float), (indices, int), (value, float)], => Float,
151            B::float_scatter_add(dim, tensor, indices, value)
152        )
153    }
154
155    fn float_select(
156        tensor: FloatTensor<Self>,
157        dim: usize,
158        indices: IntTensor<Self>,
159    ) -> FloatTensor<Self> {
160        binary_float!((tensor, float), (indices, int), |tensor, indices| B::float_select(tensor, dim, indices) => Float)
161    }
162
163    fn float_select_add(
164        tensor: FloatTensor<Self>,
165        dim: usize,
166        indices: IntTensor<Self>,
167        value: FloatTensor<Self>,
168    ) -> FloatTensor<Self> {
169        multi_op!(
170            inputs[(tensor, float), (indices, int), (value, float)], => Float,
171            B::float_select_add(tensor, dim, indices, value)
172        )
173    }
174
175    fn float_slice(tensor: FloatTensor<Self>, slices: &[Slice]) -> FloatTensor<Self> {
176        unary_float!(tensor, float, |tensor| B::float_slice(tensor, slices) => Float)
177    }
178
179    fn float_slice_assign(
180        tensor: FloatTensor<Self>,
181        slices: &[Slice],
182        value: FloatTensor<Self>,
183    ) -> FloatTensor<Self> {
184        binary_float!((tensor, float), (value, float), |tensor, value| B::float_slice_assign(tensor, slices, value) => Float)
185    }
186
187    fn float_mask_where(
188        tensor: FloatTensor<Self>,
189        mask: BoolTensor<Self>,
190        value: FloatTensor<Self>,
191    ) -> FloatTensor<Self> {
192        multi_op!(
193            inputs[(tensor, float), (mask, bool), (value, float)], => Float,
194            B::float_mask_where(tensor, mask, value)
195        )
196    }
197
198    fn float_mask_fill(
199        tensor: FloatTensor<Self>,
200        mask: BoolTensor<Self>,
201        value: Scalar,
202    ) -> FloatTensor<Self> {
203        binary_float!((tensor, float), (mask, bool), |tensor, mask| B::float_mask_fill(tensor, mask, value) => Float)
204    }
205
206    fn float_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
207        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_equal(lhs, rhs) => Bool)
208    }
209
210    fn float_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
211        unary_float!(lhs, float, |lhs| B::float_equal_elem(lhs, rhs) => Bool)
212    }
213
214    fn float_greater(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
215        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_greater(lhs, rhs) => Bool)
216    }
217
218    fn float_greater_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
219        unary_float!(lhs, float, |lhs| B::float_greater_elem(lhs, rhs) => Bool)
220    }
221
222    fn float_greater_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
223        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_greater_equal(lhs, rhs) => Bool)
224    }
225
226    fn float_greater_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
227        unary_float!(lhs, float, |lhs| B::float_greater_equal_elem(lhs, rhs) => Bool)
228    }
229
230    fn float_lower(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
231        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_lower(lhs, rhs) => Bool)
232    }
233
234    fn float_lower_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
235        unary_float!(lhs, float, |lhs| B::float_lower_elem(lhs, rhs) => Bool)
236    }
237
238    fn float_lower_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
239        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_lower_equal(lhs, rhs) => Bool)
240    }
241
242    fn float_lower_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
243        unary_float!(lhs, float, |lhs| B::float_lower_equal_elem(lhs, rhs) => Bool)
244    }
245
246    fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
247        unary_float!(tensor, float, |tensor| B::float_sum(tensor) => Float)
248    }
249
250    fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
251        unary_float!(tensor, float, |tensor| B::float_sum_dim(tensor, dim) => Float)
252    }
253
254    fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
255        unary_float!(tensor, float, |tensor| B::float_mean_dim(tensor, dim) => Float)
256    }
257
258    fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
259        unary_float!(tensor, float, |tensor| B::float_cumsum(tensor, dim) => Float)
260    }
261
262    fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
263        unary_float!(tensor, float, |tensor| B::float_cumprod(tensor, dim) => Float)
264    }
265
266    fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
267        unary_float!(tensor, float, |tensor| B::float_cummin(tensor, dim) => Float)
268    }
269
270    fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
271        unary_float!(tensor, float, |tensor| B::float_cummax(tensor, dim) => Float)
272    }
273
274    fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
275        unary_float!(tensor, float, |tensor| B::float_cast(tensor, dtype) => Float)
276    }
277
278    fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
279        unary_float!(tensor, float, |tensor| B::float_exp(tensor) => Float)
280    }
281
282    fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
283        unary_float!(tensor, float, |tensor| B::float_log(tensor) => Float)
284    }
285
286    fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
287        unary_float!(tensor, float, |tensor| B::float_log1p(tensor) => Float)
288    }
289
290    fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
291        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_powf(lhs, rhs) => Float)
292    }
293
294    fn float_powf_scalar_impl(tensor: FloatTensor<Self>, value: Scalar) -> FloatTensor<Self> {
295        unary_float!(tensor, float, |tensor| B::float_powf_scalar_impl(tensor, value) => Float)
296    }
297
298    fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
299        unary_float!(tensor, float, |tensor| B::float_sqrt(tensor) => Float)
300    }
301
302    fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
303        unary_float!(tensor, float, |tensor| B::float_abs(tensor) => Float)
304    }
305
306    fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
307        unary_float!(tensor, float, |tensor| B::float_cos(tensor) => Float)
308    }
309
310    fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
311        unary_float!(tensor, float, |tensor| B::float_sin(tensor) => Float)
312    }
313
314    fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
315        unary_float!(tensor, float, |tensor| B::float_tan(tensor) => Float)
316    }
317
318    fn float_cosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
319        unary_float!(tensor, float, |tensor| B::float_cosh(tensor) => Float)
320    }
321
322    fn float_sinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
323        unary_float!(tensor, float, |tensor| B::float_sinh(tensor) => Float)
324    }
325
326    fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
327        unary_float!(tensor, float, |tensor| B::float_tanh(tensor) => Float)
328    }
329
330    fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
331        unary_float!(tensor, float, |tensor| B::float_acos(tensor) => Float)
332    }
333
334    fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
335        unary_float!(tensor, float, |tensor| B::float_acosh(tensor) => Float)
336    }
337
338    fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
339        unary_float!(tensor, float, |tensor| B::float_asin(tensor) => Float)
340    }
341
342    fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
343        unary_float!(tensor, float, |tensor| B::float_asinh(tensor) => Float)
344    }
345
346    fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
347        unary_float!(tensor, float, |tensor| B::float_atan(tensor) => Float)
348    }
349
350    fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
351        unary_float!(tensor, float, |tensor| B::float_atanh(tensor) => Float)
352    }
353
354    fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
355        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_atan2(lhs, rhs) => Float)
356    }
357
358    fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
359        unary_float!(tensor, float, |tensor| B::float_round(tensor) => Float)
360    }
361
362    fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
363        unary_float!(tensor, float, |tensor| B::float_floor(tensor) => Float)
364    }
365
366    fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
367        unary_float!(tensor, float, |tensor| B::float_ceil(tensor) => Float)
368    }
369
370    fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
371        unary_float!(tensor, float, |tensor| B::float_trunc(tensor) => Float)
372    }
373
374    fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
375        unary_float!(tensor, float, |tensor| B::float_erf(tensor) => Float)
376    }
377
378    fn float_argmax(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {
379        unary_float!(tensor, float, |tensor| B::float_argmax(tensor, dim) => Int)
380    }
381
382    fn float_argmin(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {
383        unary_float!(tensor, float, |tensor| B::float_argmin(tensor, dim) => Int)
384    }
385
386    fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
387        unary_float!(tensor, float, |tensor| B::float_expand(tensor, shape) => Float)
388    }
389
390    fn float_unfold(
391        tensor: FloatTensor<Self>,
392        dim: usize,
393        size: usize,
394        step: usize,
395    ) -> FloatTensor<Self> {
396        unary_float!(tensor, float, |tensor| {
397            B::float_unfold(tensor, dim, size, step)
398        } => Float)
399    }
400
401    fn float_detach(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
402        unary_float!(tensor, float, |tensor| B::float_detach(tensor) => Float)
403    }
404
405    fn float_set_require_grad(tensor: FloatTensor<Self>, require_grad: bool) -> FloatTensor<Self> {
406        unary_float!(tensor, float, |tensor| B::float_set_require_grad(tensor, require_grad) => Float)
407    }
408
409    fn float_is_require_grad(tensor: &FloatTensor<Self>) -> bool {
410        unary_float!(ref tensor, float, |tensor| B::float_is_require_grad(tensor))
411    }
412
413    // Default implementation
414    fn float_zeros(shape: Shape, device: &DispatchDevice, dtype: FloatDType) -> FloatTensor<Self> {
415        creation_op!(Float, device, |device| B::float_zeros(shape, device, dtype))
416    }
417
418    fn float_ones(shape: Shape, device: &DispatchDevice, dtype: FloatDType) -> FloatTensor<Self> {
419        creation_op!(Float, device, |device| B::float_ones(shape, device, dtype))
420    }
421
422    fn float_full(
423        shape: Shape,
424        fill_value: Scalar,
425        device: &DispatchDevice,
426        dtype: FloatDType,
427    ) -> FloatTensor<Self> {
428        creation_op!(Float, device, |device| B::float_full(
429            shape, fill_value, device, dtype
430        ))
431    }
432
433    fn float_repeat_dim(tensor: FloatTensor<Self>, dim: usize, times: usize) -> FloatTensor<Self> {
434        unary_float!(tensor, float, |tensor| B::float_repeat_dim(tensor, dim, times) => Float)
435    }
436
437    fn float_clamp_min(tensor: FloatTensor<Self>, min: Scalar) -> FloatTensor<Self> {
438        unary_float!(tensor, float, |tensor| B::float_clamp_min(tensor, min) => Float)
439    }
440
441    fn float_clamp_max(tensor: FloatTensor<Self>, max: Scalar) -> FloatTensor<Self> {
442        unary_float!(tensor, float, |tensor| B::float_clamp_max(tensor, max) => Float)
443    }
444
445    fn float_clamp(tensor: FloatTensor<Self>, min: Scalar, max: Scalar) -> FloatTensor<Self> {
446        unary_float!(tensor, float, |tensor| B::float_clamp(tensor, min, max) => Float)
447    }
448
449    fn float_neg(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
450        unary_float!(tensor, float, |tensor| B::float_neg(tensor) => Float)
451    }
452
453    fn float_transpose(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
454        unary_float!(tensor, float, |tensor| B::float_transpose(tensor) => Float)
455    }
456
457    fn float_not_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
458        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_not_equal(lhs, rhs) => Bool)
459    }
460
461    fn float_not_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
462        unary_float!(lhs, float, |lhs| B::float_not_equal_elem(lhs, rhs) => Bool)
463    }
464
465    fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
466        unary_float!(tensor, float, |tensor| B::float_prod(tensor) => Float)
467    }
468
469    fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
470        unary_float!(tensor, float, |tensor| B::float_prod_dim(tensor, dim) => Float)
471    }
472
473    fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
474        unary_float!(tensor, float, |tensor| B::float_mean(tensor) => Float)
475    }
476
477    fn float_powi(lhs: FloatTensor<Self>, rhs: IntTensor<Self>) -> FloatTensor<Self> {
478        binary_float!((lhs, float), (rhs, int), |lhs, rhs| B::float_powi(lhs, rhs) => Float)
479    }
480
481    fn float_powi_scalar_impl(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
482        unary_float!(lhs, float, |lhs| B::float_powi_scalar_impl(lhs, rhs) => Float)
483    }
484
485    fn float_powf_scalar(tensor: FloatTensor<Self>, value: Scalar) -> FloatTensor<Self> {
486        unary_float!(tensor, float, |tensor| B::float_powf_scalar(tensor, value) => Float)
487    }
488
489    fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {
490        vec_op!(tensors, float, |tensors| B::float_cat(tensors, dim) => Float)
491    }
492
493    fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
494        unary_float!(tensor, float, |tensor| B::float_max(tensor) => Float)
495    }
496
497    fn float_max_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
498        unary_float!(tensor, float, |tensor| B::float_max_dim(tensor, dim) => Float)
499    }
500
501    fn float_max_dim_with_indices(
502        tensor: FloatTensor<Self>,
503        dim: usize,
504    ) -> (FloatTensor<Self>, IntTensor<Self>) {
505        multi_op!(
506            inputs[(tensor, float)],
507            outputs[(out, Float), (indices, Int)],
508            B::float_max_dim_with_indices(tensor, dim)
509        )
510    }
511
512    fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
513        unary_float!(tensor, float, |tensor| B::float_min(tensor) => Float)
514    }
515
516    fn float_min_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
517        unary_float!(tensor, float, |tensor| B::float_min_dim(tensor, dim) => Float)
518    }
519
520    fn float_min_dim_with_indices(
521        tensor: FloatTensor<Self>,
522        dim: usize,
523    ) -> (FloatTensor<Self>, IntTensor<Self>) {
524        multi_op!(
525            inputs[(tensor, float)],
526            outputs[(out, Float), (indices, Int)],
527            B::float_min_dim_with_indices(tensor, dim)
528        )
529    }
530
531    fn float_max_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
532        unary_float!(tensor, float, |tensor| B::float_max_abs(tensor) => Float)
533    }
534
535    fn float_max_abs_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
536        unary_float!(tensor, float, |tensor| B::float_max_abs_dim(tensor, dim) => Float)
537    }
538
539    fn float_any(tensor: FloatTensor<Self>) -> BoolTensor<Self> {
540        unary_float!(tensor, float, |tensor| B::float_any(tensor) => Bool)
541    }
542
543    fn float_any_dim(tensor: FloatTensor<Self>, dim: usize) -> BoolTensor<Self> {
544        unary_float!(tensor, float, |tensor| B::float_any_dim(tensor, dim) => Bool)
545    }
546
547    fn float_all(tensor: FloatTensor<Self>) -> BoolTensor<Self> {
548        unary_float!(tensor, float, |tensor| B::float_all(tensor) => Bool)
549    }
550
551    fn float_all_dim(tensor: FloatTensor<Self>, dim: usize) -> BoolTensor<Self> {
552        unary_float!(tensor, float, |tensor| B::float_all_dim(tensor, dim) => Bool)
553    }
554
555    fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
556        unary_float!(tensor, float, |tensor| B::float_sign(tensor) => Float)
557    }
558
559    fn float_sort(tensor: FloatTensor<Self>, dim: usize, descending: bool) -> FloatTensor<Self> {
560        unary_float!(tensor, float, |tensor| B::float_sort(tensor, dim, descending) => Float)
561    }
562
563    fn float_sort_with_indices(
564        tensor: FloatTensor<Self>,
565        dim: usize,
566        descending: bool,
567    ) -> (FloatTensor<Self>, IntTensor<Self>) {
568        multi_op!(
569            inputs[(tensor, float)],
570            outputs[(out, Float), (indices, Int)],
571            B::float_sort_with_indices(tensor, dim, descending)
572        )
573    }
574
575    fn float_argsort(tensor: FloatTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {
576        unary_float!(tensor, float, |tensor| B::float_argsort(tensor, dim, descending) => Int)
577    }
578
579    fn float_grid_sample_2d(
580        tensor: FloatTensor<Self>,
581        grid: FloatTensor<Self>,
582        options: burn_backend::ops::GridSampleOptions,
583    ) -> FloatTensor<Self> {
584        binary_float!((tensor, float), (grid, float), |tensor, grid| B::float_grid_sample_2d(tensor, grid, options) => Float)
585    }
586
587    fn float_is_nan(tensor: FloatTensor<Self>) -> BoolTensor<Self> {
588        unary_float!(tensor, float, |tensor| B::float_is_nan(tensor) => Bool)
589    }
590
591    fn float_is_inf(tensor: FloatTensor<Self>) -> BoolTensor<Self> {
592        unary_float!(tensor, float, |tensor| B::float_is_inf(tensor) => Bool)
593    }
594}