1use alloc::vec::Vec;
2use burn_backend::{
3 BoolDType, ExecutionError, FloatDType, IntDType, Scalar, Shape, Slice, TensorData,
4 ops::FloatTensorOps,
5 tensor::{BoolTensor, FloatTensor, IntTensor},
6};
7
8use crate::backends::*;
9use crate::{Dispatch, DispatchDevice};
10
/// `FloatTensorOps` implementation for `Dispatch`.
///
/// Every method here is a thin forwarder: it hands its arguments to one of the crate's
/// dispatch macros (`creation_op!`, `unary_float!`, `binary_float!`, `multi_op!`,
/// `vec_op!`, `float_to_device!`), and the expression given to the macro calls the
/// same-named function on `B`.
///
/// NOTE(review): the macros and the `B` / `B1` / `B2` names are defined outside this file.
/// Presumably the macros match on the tensor/device variant and bind `B` to the matching
/// concrete backend, and the trailing `=> Float` / `=> Int` / `=> Bool` names the variant
/// the result is wrapped in — confirm against the macro definitions.
impl FloatTensorOps<Self> for Dispatch {
    /// Forwards to `B::float_from_data(data, device)` via `creation_op!` with kind `Float`.
    fn float_from_data(
        data: burn_backend::TensorData,
        device: &DispatchDevice,
    ) -> FloatTensor<Self> {
        creation_op!(Float, device, |device| B::float_from_data(data, device))
    }

    /// Forwards to `B::float_random(shape, distribution, device, dtype)` via `creation_op!`
    /// with kind `Float`.
    fn float_random(
        shape: Shape,
        distribution: burn_backend::Distribution,
        device: &DispatchDevice,
        dtype: FloatDType,
    ) -> FloatTensor<Self> {
        creation_op!(Float, device, |device| {
            B::float_random(shape, distribution, device, dtype)
        })
    }

    /// Awaits `B::float_into_data(tensor)` via `unary_float!` and returns its
    /// `Result<TensorData, ExecutionError>` as-is (no `=> Kind` tail is passed).
    async fn float_into_data(tensor: FloatTensor<Self>) -> Result<TensorData, ExecutionError> {
        unary_float!(tensor, float, |tensor| B::float_into_data(tensor).await)
    }

    /// Returns `tensor.device()`; no macro dispatch is involved.
    fn float_device(tensor: &FloatTensor<Self>) -> DispatchDevice {
        tensor.device()
    }

    /// Moves `tensor` to `device` through the `float_to_device!` macro.
    ///
    /// The macro receives the method name `float_to_device` plus a closure that reads the
    /// tensor's data synchronously from `B1` and rebuilds it on `B2` with
    /// `float_from_data`.
    ///
    /// NOTE(review): presumably the named method handles the same-backend case and the
    /// closure is the cross-backend path — confirm in `float_to_device!`.
    ///
    /// # Panics
    ///
    /// The closure panics with "Should read data" if the synchronous read returns an error.
    fn float_to_device(tensor: FloatTensor<Self>, device: &DispatchDevice) -> FloatTensor<Self> {
        float_to_device!(
            Float,
            float,
            tensor,
            device,
            float_to_device,
            |inner, device| {
                let data =
                    burn_backend::read_sync(B1::float_into_data(inner)).expect("Should read data");
                B2::float_from_data(data, device)
            }
        )
    }

    /// Forwards to `B::float_into_int(tensor, dtype)` via `unary_float!`; output kind `Int`.
    fn float_into_int(tensor: FloatTensor<Self>, dtype: burn_backend::IntDType) -> IntTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_into_int(tensor, dtype) => Int)
    }

    /// Forwards to `B::float_empty(shape, device, dtype)` via `creation_op!` with kind `Float`.
    fn float_empty(shape: Shape, device: &DispatchDevice, dtype: FloatDType) -> FloatTensor<Self> {
        creation_op!(Float, device, |device| B::float_empty(shape, device, dtype))
    }

    /// Forwards to `B::float_add(lhs, rhs)` via `binary_float!`; output kind `Float`.
    fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_add(lhs, rhs) => Float)
    }

    /// Forwards to `B::float_add_scalar(lhs, rhs)` via `unary_float!`; output kind `Float`.
    fn float_add_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
        unary_float!(lhs, float, |lhs| B::float_add_scalar(lhs, rhs) => Float)
    }

    /// Forwards to `B::float_sub(lhs, rhs)` via `binary_float!`; output kind `Float`.
    fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_sub(lhs, rhs) => Float)
    }

    /// Forwards to `B::float_sub_scalar(lhs, rhs)` via `unary_float!`; output kind `Float`.
    fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
        unary_float!(lhs, float, |lhs| B::float_sub_scalar(lhs, rhs) => Float)
    }

    /// Forwards to `B::float_mul(lhs, rhs)` via `binary_float!`; output kind `Float`.
    fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_mul(lhs, rhs) => Float)
    }

    /// Forwards to `B::float_mul_scalar(lhs, rhs)` via `unary_float!`; output kind `Float`.
    fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
        unary_float!(lhs, float, |lhs| B::float_mul_scalar(lhs, rhs) => Float)
    }

    /// Forwards to `B::float_div(lhs, rhs)` via `binary_float!`; output kind `Float`.
    fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_div(lhs, rhs) => Float)
    }

    /// Forwards to `B::float_div_scalar(lhs, rhs)` via `unary_float!`; output kind `Float`.
    fn float_div_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
        unary_float!(lhs, float, |lhs| B::float_div_scalar(lhs, rhs) => Float)
    }

    /// Forwards to `B::float_remainder(lhs, rhs)` via `binary_float!`; output kind `Float`.
    fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_remainder(lhs, rhs) => Float)
    }

    /// Forwards to `B::float_remainder_scalar(lhs, rhs)` via `unary_float!`; output kind
    /// `Float`.
    fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
        unary_float!(lhs, float, |lhs| B::float_remainder_scalar(lhs, rhs) => Float)
    }

    /// Forwards to `B::float_matmul(lhs, rhs)` via `binary_float!`; output kind `Float`.
    fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_matmul(lhs, rhs) => Float)
    }

    /// Forwards to `B::float_cross(lhs, rhs, dim)` via `binary_float!`; output kind `Float`.
    fn float_cross(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        dim: usize,
    ) -> FloatTensor<Self> {
        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_cross(lhs, rhs, dim) => Float)
    }

    /// Forwards to `B::float_recip(tensor)` via `unary_float!`; output kind `Float`.
    fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_recip(tensor) => Float)
    }

    /// Forwards to `B::float_swap_dims(tensor, dim1, dim2)` via `unary_float!`; output kind
    /// `Float`.
    fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_swap_dims(tensor, dim1, dim2) => Float)
    }

    /// Forwards to `B::float_permute(tensor, axes)` via `unary_float!`; output kind `Float`.
    fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_permute(tensor, axes) => Float)
    }

    /// Forwards to `B::float_flip(tensor, axes)` via `unary_float!`; output kind `Float`.
    fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_flip(tensor, axes) => Float)
    }

    /// Forwards to `B::float_reshape(tensor, shape)` via `unary_float!`; output kind `Float`.
    fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_reshape(tensor, shape) => Float)
    }

    /// Forwards to `B::float_gather(dim, tensor, indices)` via `binary_float!`, with `tensor`
    /// tagged `float` and `indices` tagged `int`; output kind `Float`.
    fn float_gather(
        dim: usize,
        tensor: FloatTensor<Self>,
        indices: IntTensor<Self>,
    ) -> FloatTensor<Self> {
        binary_float!((tensor, float), (indices, int), |tensor, indices| B::float_gather(dim, tensor, indices) => Float)
    }

    /// Forwards to `B::float_scatter_add(dim, tensor, indices, value)` via `multi_op!` with
    /// three inputs (`float`, `int`, `float`); output kind `Float`.
    fn float_scatter_add(
        dim: usize,
        tensor: FloatTensor<Self>,
        indices: IntTensor<Self>,
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        multi_op!(
            inputs[(tensor, float), (indices, int), (value, float)], => Float,
            B::float_scatter_add(dim, tensor, indices, value)
        )
    }

    /// Forwards to `B::float_scatter_nd(data, indices, values, reduction)` via `multi_op!`
    /// with three inputs (`float`, `int`, `float`); output kind `Float`.
    fn float_scatter_nd(
        data: FloatTensor<Self>,
        indices: IntTensor<Self>,
        values: FloatTensor<Self>,
        reduction: burn_backend::tensor::IndexingUpdateOp,
    ) -> FloatTensor<Self> {
        multi_op!(
            inputs[(data, float), (indices, int), (values, float)], => Float,
            B::float_scatter_nd(data, indices, values, reduction)
        )
    }

    /// Forwards to `B::float_gather_nd(data, indices)` via `binary_float!`, with `data`
    /// tagged `float` and `indices` tagged `int`; output kind `Float`.
    fn float_gather_nd(data: FloatTensor<Self>, indices: IntTensor<Self>) -> FloatTensor<Self> {
        binary_float!((data, float), (indices, int), |data, indices| B::float_gather_nd(data, indices) => Float)
    }

    /// Forwards to `B::float_select(tensor, dim, indices)` via `binary_float!`, with `tensor`
    /// tagged `float` and `indices` tagged `int`; output kind `Float`.
    fn float_select(
        tensor: FloatTensor<Self>,
        dim: usize,
        indices: IntTensor<Self>,
    ) -> FloatTensor<Self> {
        binary_float!((tensor, float), (indices, int), |tensor, indices| B::float_select(tensor, dim, indices) => Float)
    }

    /// Forwards to `B::float_select_add(tensor, dim, indices, value)` via `multi_op!` with
    /// three inputs (`float`, `int`, `float`); output kind `Float`.
    fn float_select_add(
        tensor: FloatTensor<Self>,
        dim: usize,
        indices: IntTensor<Self>,
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        multi_op!(
            inputs[(tensor, float), (indices, int), (value, float)], => Float,
            B::float_select_add(tensor, dim, indices, value)
        )
    }

    /// Forwards to `B::float_slice(tensor, slices)` via `unary_float!`; output kind `Float`.
    fn float_slice(tensor: FloatTensor<Self>, slices: &[Slice]) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_slice(tensor, slices) => Float)
    }

    /// Forwards to `B::float_slice_assign(tensor, slices, value)` via `binary_float!`; output
    /// kind `Float`.
    fn float_slice_assign(
        tensor: FloatTensor<Self>,
        slices: &[Slice],
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        binary_float!((tensor, float), (value, float), |tensor, value| B::float_slice_assign(tensor, slices, value) => Float)
    }

    /// Forwards to `B::float_mask_where(tensor, mask, value)` via `multi_op!` with three
    /// inputs (`float`, `bool`, `float`); output kind `Float`.
    fn float_mask_where(
        tensor: FloatTensor<Self>,
        mask: BoolTensor<Self>,
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        multi_op!(
            inputs[(tensor, float), (mask, bool), (value, float)], => Float,
            B::float_mask_where(tensor, mask, value)
        )
    }

    /// Forwards to `B::float_mask_fill(tensor, mask, value)` via `binary_float!`, with
    /// `tensor` tagged `float` and `mask` tagged `bool`; output kind `Float`.
    fn float_mask_fill(
        tensor: FloatTensor<Self>,
        mask: BoolTensor<Self>,
        value: Scalar,
    ) -> FloatTensor<Self> {
        binary_float!((tensor, float), (mask, bool), |tensor, mask| B::float_mask_fill(tensor, mask, value) => Float)
    }

    /// Forwards to `B::float_equal(lhs, rhs, out_dtype)` via `binary_float!`; output kind
    /// `Bool`.
    fn float_equal(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        out_dtype: BoolDType,
    ) -> BoolTensor<Self> {
        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_equal(lhs, rhs, out_dtype) => Bool)
    }

    /// Forwards to `B::float_equal_elem(lhs, rhs, out_dtype)` via `unary_float!`; output kind
    /// `Bool`.
    fn float_equal_elem(
        lhs: FloatTensor<Self>,
        rhs: Scalar,
        out_dtype: BoolDType,
    ) -> BoolTensor<Self> {
        unary_float!(lhs, float, |lhs| B::float_equal_elem(lhs, rhs, out_dtype) => Bool)
    }

    /// Forwards to `B::float_greater(lhs, rhs, out_dtype)` via `binary_float!`; output kind
    /// `Bool`.
    fn float_greater(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        out_dtype: BoolDType,
    ) -> BoolTensor<Self> {
        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_greater(lhs, rhs, out_dtype) => Bool)
    }

    /// Forwards to `B::float_greater_elem(lhs, rhs, out_dtype)` via `unary_float!`; output
    /// kind `Bool`.
    fn float_greater_elem(
        lhs: FloatTensor<Self>,
        rhs: Scalar,
        out_dtype: BoolDType,
    ) -> BoolTensor<Self> {
        unary_float!(lhs, float, |lhs| B::float_greater_elem(lhs, rhs, out_dtype) => Bool)
    }

    /// Forwards to `B::float_greater_equal(lhs, rhs, out_dtype)` via `binary_float!`; output
    /// kind `Bool`.
    fn float_greater_equal(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        out_dtype: BoolDType,
    ) -> BoolTensor<Self> {
        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_greater_equal(lhs, rhs, out_dtype) => Bool)
    }

    /// Forwards to `B::float_greater_equal_elem(lhs, rhs, out_dtype)` via `unary_float!`;
    /// output kind `Bool`.
    fn float_greater_equal_elem(
        lhs: FloatTensor<Self>,
        rhs: Scalar,
        out_dtype: BoolDType,
    ) -> BoolTensor<Self> {
        unary_float!(lhs, float, |lhs| B::float_greater_equal_elem(lhs, rhs, out_dtype) => Bool)
    }

    /// Forwards to `B::float_lower(lhs, rhs, out_dtype)` via `binary_float!`; output kind
    /// `Bool`.
    fn float_lower(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        out_dtype: BoolDType,
    ) -> BoolTensor<Self> {
        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_lower(lhs, rhs, out_dtype) => Bool)
    }

    /// Forwards to `B::float_lower_elem(lhs, rhs, out_dtype)` via `unary_float!`; output kind
    /// `Bool`.
    fn float_lower_elem(
        lhs: FloatTensor<Self>,
        rhs: Scalar,
        out_dtype: BoolDType,
    ) -> BoolTensor<Self> {
        unary_float!(lhs, float, |lhs| B::float_lower_elem(lhs, rhs, out_dtype) => Bool)
    }

    /// Forwards to `B::float_lower_equal(lhs, rhs, out_dtype)` via `binary_float!`; output
    /// kind `Bool`.
    fn float_lower_equal(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        out_dtype: BoolDType,
    ) -> BoolTensor<Self> {
        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_lower_equal(lhs, rhs, out_dtype) => Bool)
    }

    /// Forwards to `B::float_lower_equal_elem(lhs, rhs, out_dtype)` via `unary_float!`;
    /// output kind `Bool`.
    fn float_lower_equal_elem(
        lhs: FloatTensor<Self>,
        rhs: Scalar,
        out_dtype: BoolDType,
    ) -> BoolTensor<Self> {
        unary_float!(lhs, float, |lhs| B::float_lower_equal_elem(lhs, rhs, out_dtype) => Bool)
    }

    /// Forwards to `B::float_sum(tensor)` via `unary_float!`; output kind `Float`.
    fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_sum(tensor) => Float)
    }

    /// Forwards to `B::float_sum_dim(tensor, dim)` via `unary_float!`; output kind `Float`.
    fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_sum_dim(tensor, dim) => Float)
    }

    /// Forwards to `B::float_mean_dim(tensor, dim)` via `unary_float!`; output kind `Float`.
    fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_mean_dim(tensor, dim) => Float)
    }

    /// Forwards to `B::float_cumsum(tensor, dim)` via `unary_float!`; output kind `Float`.
    fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_cumsum(tensor, dim) => Float)
    }

    /// Forwards to `B::float_cumprod(tensor, dim)` via `unary_float!`; output kind `Float`.
    fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_cumprod(tensor, dim) => Float)
    }

    /// Forwards to `B::float_cummin(tensor, dim)` via `unary_float!`; output kind `Float`.
    fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_cummin(tensor, dim) => Float)
    }

    /// Forwards to `B::float_cummax(tensor, dim)` via `unary_float!`; output kind `Float`.
    fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_cummax(tensor, dim) => Float)
    }

    /// Forwards to `B::float_cast(tensor, dtype)` via `unary_float!`; output kind `Float`.
    fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_cast(tensor, dtype) => Float)
    }

    /// Forwards to `B::float_exp(tensor)` via `unary_float!`; output kind `Float`.
    fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_exp(tensor) => Float)
    }

    /// Forwards to `B::float_log(tensor)` via `unary_float!`; output kind `Float`.
    fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_log(tensor) => Float)
    }

    /// Forwards to `B::float_log1p(tensor)` via `unary_float!`; output kind `Float`.
    fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_log1p(tensor) => Float)
    }

    /// Forwards to `B::float_powf(lhs, rhs)` via `binary_float!`; output kind `Float`.
    fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_powf(lhs, rhs) => Float)
    }

    /// Forwards to `B::float_powf_scalar_impl(tensor, value)` via `unary_float!`; output kind
    /// `Float`.
    fn float_powf_scalar_impl(tensor: FloatTensor<Self>, value: Scalar) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_powf_scalar_impl(tensor, value) => Float)
    }

    /// Forwards to `B::float_sqrt(tensor)` via `unary_float!`; output kind `Float`.
    fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_sqrt(tensor) => Float)
    }

    /// Forwards to `B::float_abs(tensor)` via `unary_float!`; output kind `Float`.
    fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_abs(tensor) => Float)
    }

    /// Forwards to `B::float_cos(tensor)` via `unary_float!`; output kind `Float`.
    fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_cos(tensor) => Float)
    }

    /// Forwards to `B::float_sin(tensor)` via `unary_float!`; output kind `Float`.
    fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_sin(tensor) => Float)
    }

    /// Forwards to `B::float_tan(tensor)` via `unary_float!`; output kind `Float`.
    fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_tan(tensor) => Float)
    }

    /// Forwards to `B::float_cosh(tensor)` via `unary_float!`; output kind `Float`.
    fn float_cosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_cosh(tensor) => Float)
    }

    /// Forwards to `B::float_sinh(tensor)` via `unary_float!`; output kind `Float`.
    fn float_sinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_sinh(tensor) => Float)
    }

    /// Forwards to `B::float_tanh(tensor)` via `unary_float!`; output kind `Float`.
    fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_tanh(tensor) => Float)
    }

    /// Forwards to `B::float_acos(tensor)` via `unary_float!`; output kind `Float`.
    fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_acos(tensor) => Float)
    }

    /// Forwards to `B::float_acosh(tensor)` via `unary_float!`; output kind `Float`.
    fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_acosh(tensor) => Float)
    }

    /// Forwards to `B::float_asin(tensor)` via `unary_float!`; output kind `Float`.
    fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_asin(tensor) => Float)
    }

    /// Forwards to `B::float_asinh(tensor)` via `unary_float!`; output kind `Float`.
    fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_asinh(tensor) => Float)
    }

    /// Forwards to `B::float_atan(tensor)` via `unary_float!`; output kind `Float`.
    fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_atan(tensor) => Float)
    }

    /// Forwards to `B::float_atanh(tensor)` via `unary_float!`; output kind `Float`.
    fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_atanh(tensor) => Float)
    }

    /// Forwards to `B::float_atan2(lhs, rhs)` via `binary_float!`; output kind `Float`.
    fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_atan2(lhs, rhs) => Float)
    }

    /// Forwards to `B::float_round(tensor)` via `unary_float!`; output kind `Float`.
    fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_round(tensor) => Float)
    }

    /// Forwards to `B::float_floor(tensor)` via `unary_float!`; output kind `Float`.
    fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_floor(tensor) => Float)
    }

    /// Forwards to `B::float_ceil(tensor)` via `unary_float!`; output kind `Float`.
    fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_ceil(tensor) => Float)
    }

    /// Forwards to `B::float_trunc(tensor)` via `unary_float!`; output kind `Float`.
    fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_trunc(tensor) => Float)
    }

    /// Forwards to `B::float_erf(tensor)` via `unary_float!`; output kind `Float`.
    fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_erf(tensor) => Float)
    }

    /// Forwards to `B::float_argmax(tensor, dim, out_dtype)` via `unary_float!`; output kind
    /// `Int`.
    fn float_argmax(tensor: FloatTensor<Self>, dim: usize, out_dtype: IntDType) -> IntTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_argmax(tensor, dim, out_dtype) => Int)
    }

    /// Forwards to `B::float_argtopk(tensor, dim, k, out_dtype)` via `unary_float!`; output
    /// kind `Int`.
    fn float_argtopk(
        tensor: FloatTensor<Self>,
        dim: usize,
        k: usize,
        out_dtype: IntDType,
    ) -> IntTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_argtopk(tensor, dim, k, out_dtype) => Int)
    }

    /// Forwards to `B::float_topk(tensor, dim, k)` via `unary_float!`; output kind `Float`.
    fn float_topk(tensor: FloatTensor<Self>, dim: usize, k: usize) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_topk(tensor, dim, k) => Float)
    }

    /// Forwards to `B::float_argmin(tensor, dim, out_dtype)` via `unary_float!`; output kind
    /// `Int`.
    fn float_argmin(tensor: FloatTensor<Self>, dim: usize, out_dtype: IntDType) -> IntTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_argmin(tensor, dim, out_dtype) => Int)
    }

    /// Forwards to `B::float_expand(tensor, shape)` via `unary_float!`; output kind `Float`.
    fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_expand(tensor, shape) => Float)
    }

    /// Forwards to `B::float_unfold(tensor, dim, size, step)` via `unary_float!`; output kind
    /// `Float`.
    fn float_unfold(
        tensor: FloatTensor<Self>,
        dim: usize,
        size: usize,
        step: usize,
    ) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| {
            B::float_unfold(tensor, dim, size, step)
        } => Float)
    }

    /// Forwards to `B::float_detach(tensor)` via `unary_float!`; output kind `Float`.
    fn float_detach(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_detach(tensor) => Float)
    }

    /// Forwards to `B::float_set_require_grad(tensor, require_grad)` via `unary_float!`;
    /// output kind `Float`.
    fn float_set_require_grad(tensor: FloatTensor<Self>, require_grad: bool) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_set_require_grad(tensor, require_grad) => Float)
    }

    /// Forwards to `B::float_is_require_grad(tensor)` using the `ref` form of
    /// `unary_float!` (the tensor is borrowed here); the `bool` is returned as-is, with no
    /// `=> Kind` tail.
    fn float_is_require_grad(tensor: &FloatTensor<Self>) -> bool {
        unary_float!(ref tensor, float, |tensor| B::float_is_require_grad(tensor))
    }

    /// Forwards to `B::float_zeros(shape, device, dtype)` via `creation_op!` with kind `Float`.
    fn float_zeros(shape: Shape, device: &DispatchDevice, dtype: FloatDType) -> FloatTensor<Self> {
        creation_op!(Float, device, |device| B::float_zeros(shape, device, dtype))
    }

    /// Forwards to `B::float_ones(shape, device, dtype)` via `creation_op!` with kind `Float`.
    fn float_ones(shape: Shape, device: &DispatchDevice, dtype: FloatDType) -> FloatTensor<Self> {
        creation_op!(Float, device, |device| B::float_ones(shape, device, dtype))
    }

    /// Forwards to `B::float_full(shape, fill_value, device, dtype)` via `creation_op!` with
    /// kind `Float`.
    fn float_full(
        shape: Shape,
        fill_value: Scalar,
        device: &DispatchDevice,
        dtype: FloatDType,
    ) -> FloatTensor<Self> {
        creation_op!(Float, device, |device| B::float_full(
            shape, fill_value, device, dtype
        ))
    }

    /// Forwards to `B::float_repeat_dim(tensor, dim, times)` via `unary_float!`; output kind
    /// `Float`.
    fn float_repeat_dim(tensor: FloatTensor<Self>, dim: usize, times: usize) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_repeat_dim(tensor, dim, times) => Float)
    }

    /// Forwards to `B::float_clamp_min(tensor, min)` via `unary_float!`; output kind `Float`.
    fn float_clamp_min(tensor: FloatTensor<Self>, min: Scalar) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_clamp_min(tensor, min) => Float)
    }

    /// Forwards to `B::float_clamp_max(tensor, max)` via `unary_float!`; output kind `Float`.
    fn float_clamp_max(tensor: FloatTensor<Self>, max: Scalar) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_clamp_max(tensor, max) => Float)
    }

    /// Forwards to `B::float_clamp(tensor, min, max)` via `unary_float!`; output kind `Float`.
    fn float_clamp(tensor: FloatTensor<Self>, min: Scalar, max: Scalar) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_clamp(tensor, min, max) => Float)
    }

    /// Forwards to `B::float_neg(tensor)` via `unary_float!`; output kind `Float`.
    fn float_neg(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_neg(tensor) => Float)
    }

    /// Forwards to `B::float_transpose(tensor)` via `unary_float!`; output kind `Float`.
    fn float_transpose(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_transpose(tensor) => Float)
    }

    /// Forwards to `B::float_not_equal(lhs, rhs, out_dtype)` via `binary_float!`; output kind
    /// `Bool`.
    fn float_not_equal(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        out_dtype: BoolDType,
    ) -> BoolTensor<Self> {
        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_not_equal(lhs, rhs, out_dtype) => Bool)
    }

    /// Forwards to `B::float_not_equal_elem(lhs, rhs, out_dtype)` via `unary_float!`; output
    /// kind `Bool`.
    fn float_not_equal_elem(
        lhs: FloatTensor<Self>,
        rhs: Scalar,
        out_dtype: BoolDType,
    ) -> BoolTensor<Self> {
        unary_float!(lhs, float, |lhs| B::float_not_equal_elem(lhs, rhs, out_dtype) => Bool)
    }

    /// Forwards to `B::float_prod(tensor)` via `unary_float!`; output kind `Float`.
    fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_prod(tensor) => Float)
    }

    /// Forwards to `B::float_prod_dim(tensor, dim)` via `unary_float!`; output kind `Float`.
    fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_prod_dim(tensor, dim) => Float)
    }

    /// Forwards to `B::float_mean(tensor)` via `unary_float!`; output kind `Float`.
    fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_mean(tensor) => Float)
    }

    /// Forwards to `B::float_powi(lhs, rhs)` via `binary_float!`, with `lhs` tagged `float`
    /// and `rhs` tagged `int`; output kind `Float`.
    fn float_powi(lhs: FloatTensor<Self>, rhs: IntTensor<Self>) -> FloatTensor<Self> {
        binary_float!((lhs, float), (rhs, int), |lhs, rhs| B::float_powi(lhs, rhs) => Float)
    }

    /// Forwards to `B::float_powi_scalar_impl(lhs, rhs)` via `unary_float!`; output kind
    /// `Float`.
    fn float_powi_scalar_impl(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
        unary_float!(lhs, float, |lhs| B::float_powi_scalar_impl(lhs, rhs) => Float)
    }

    /// Forwards to `B::float_powf_scalar(tensor, value)` via `unary_float!`; output kind
    /// `Float`.
    fn float_powf_scalar(tensor: FloatTensor<Self>, value: Scalar) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_powf_scalar(tensor, value) => Float)
    }

    /// Forwards to `B::float_cat(tensors, dim)` via `vec_op!` over the whole vector; output
    /// kind `Float`.
    ///
    /// NOTE(review): how `vec_op!` treats an empty vector or tensors from different
    /// backends is not visible here — confirm in the macro definition.
    fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {
        vec_op!(tensors, float, |tensors| B::float_cat(tensors, dim) => Float)
    }

    /// Forwards to `B::float_max(tensor)` via `unary_float!`; output kind `Float`.
    fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_max(tensor) => Float)
    }

    /// Forwards to `B::float_max_dim(tensor, dim)` via `unary_float!`; output kind `Float`.
    fn float_max_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_max_dim(tensor, dim) => Float)
    }

    /// Forwards to `B::float_max_dim_with_indices(tensor, dim, indices_dtype)` via
    /// `multi_op!`, declaring two outputs: `out` as `Float` and `indices` as `Int`.
    fn float_max_dim_with_indices(
        tensor: FloatTensor<Self>,
        dim: usize,
        indices_dtype: IntDType,
    ) -> (FloatTensor<Self>, IntTensor<Self>) {
        multi_op!(
            inputs[(tensor, float)],
            outputs[(out, Float), (indices, Int)],
            B::float_max_dim_with_indices(tensor, dim, indices_dtype)
        )
    }

    /// Forwards to `B::float_min(tensor)` via `unary_float!`; output kind `Float`.
    fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_min(tensor) => Float)
    }

    /// Forwards to `B::float_min_dim(tensor, dim)` via `unary_float!`; output kind `Float`.
    fn float_min_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_min_dim(tensor, dim) => Float)
    }

    /// Forwards to `B::float_min_dim_with_indices(tensor, dim, indices_dtype)` via
    /// `multi_op!`, declaring two outputs: `out` as `Float` and `indices` as `Int`.
    fn float_min_dim_with_indices(
        tensor: FloatTensor<Self>,
        dim: usize,
        indices_dtype: IntDType,
    ) -> (FloatTensor<Self>, IntTensor<Self>) {
        multi_op!(
            inputs[(tensor, float)],
            outputs[(out, Float), (indices, Int)],
            B::float_min_dim_with_indices(tensor, dim, indices_dtype)
        )
    }

    /// Forwards to `B::float_max_abs(tensor)` via `unary_float!`; output kind `Float`.
    fn float_max_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_max_abs(tensor) => Float)
    }

    /// Forwards to `B::float_max_abs_dim(tensor, dim)` via `unary_float!`; output kind
    /// `Float`.
    fn float_max_abs_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_max_abs_dim(tensor, dim) => Float)
    }

    /// Forwards to `B::float_any(tensor, out_dtype)` via `unary_float!`; output kind `Bool`.
    fn float_any(tensor: FloatTensor<Self>, out_dtype: BoolDType) -> BoolTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_any(tensor, out_dtype) => Bool)
    }

    /// Forwards to `B::float_any_dim(tensor, dim, out_dtype)` via `unary_float!`; output kind
    /// `Bool`.
    fn float_any_dim(
        tensor: FloatTensor<Self>,
        dim: usize,
        out_dtype: BoolDType,
    ) -> BoolTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_any_dim(tensor, dim, out_dtype) => Bool)
    }

    /// Forwards to `B::float_all(tensor, out_dtype)` via `unary_float!`; output kind `Bool`.
    fn float_all(tensor: FloatTensor<Self>, out_dtype: BoolDType) -> BoolTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_all(tensor, out_dtype) => Bool)
    }

    /// Forwards to `B::float_all_dim(tensor, dim, out_dtype)` via `unary_float!`; output kind
    /// `Bool`.
    fn float_all_dim(
        tensor: FloatTensor<Self>,
        dim: usize,
        out_dtype: BoolDType,
    ) -> BoolTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_all_dim(tensor, dim, out_dtype) => Bool)
    }

    /// Forwards to `B::float_sign(tensor)` via `unary_float!`; output kind `Float`.
    fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_sign(tensor) => Float)
    }

    /// Forwards to `B::float_sort(tensor, dim, descending)` via `unary_float!`; output kind
    /// `Float`.
    fn float_sort(tensor: FloatTensor<Self>, dim: usize, descending: bool) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_sort(tensor, dim, descending) => Float)
    }

    /// Forwards to `B::float_sort_with_indices(tensor, dim, descending, indices_dtype)` via
    /// `multi_op!`, declaring two outputs: `out` as `Float` and `indices` as `Int`.
    fn float_sort_with_indices(
        tensor: FloatTensor<Self>,
        dim: usize,
        descending: bool,
        indices_dtype: IntDType,
    ) -> (FloatTensor<Self>, IntTensor<Self>) {
        multi_op!(
            inputs[(tensor, float)],
            outputs[(out, Float), (indices, Int)],
            B::float_sort_with_indices(tensor, dim, descending, indices_dtype)
        )
    }

    /// Forwards to `B::float_argsort(tensor, dim, descending, out_dtype)` via `unary_float!`;
    /// output kind `Int`.
    fn float_argsort(
        tensor: FloatTensor<Self>,
        dim: usize,
        descending: bool,
        out_dtype: IntDType,
    ) -> IntTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_argsort(tensor, dim, descending, out_dtype) => Int)
    }

    /// Forwards to `B::float_grid_sample_2d(tensor, grid, options)` via `binary_float!`;
    /// output kind `Float`.
    fn float_grid_sample_2d(
        tensor: FloatTensor<Self>,
        grid: FloatTensor<Self>,
        options: burn_backend::ops::GridSampleOptions,
    ) -> FloatTensor<Self> {
        binary_float!((tensor, float), (grid, float), |tensor, grid| B::float_grid_sample_2d(tensor, grid, options) => Float)
    }

    /// Forwards to `B::float_is_nan(tensor, out_dtype)` via `unary_float!`; output kind `Bool`.
    fn float_is_nan(tensor: FloatTensor<Self>, out_dtype: BoolDType) -> BoolTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_is_nan(tensor, out_dtype) => Bool)
    }

    /// Forwards to `B::float_is_inf(tensor, out_dtype)` via `unary_float!`; output kind `Bool`.
    fn float_is_inf(tensor: FloatTensor<Self>, out_dtype: BoolDType) -> BoolTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_is_inf(tensor, out_dtype) => Bool)
    }
}