1use alloc::vec::Vec;
2use burn_backend::{
3 BoolDType, ExecutionError, FloatDType, IntDType, Scalar, Shape, Slice, TensorData,
4 ops::FloatTensorOps,
5 tensor::{BoolTensor, FloatTensor, IntTensor},
6};
7
8use crate::backends::*;
9use crate::{Dispatch, DispatchDevice};
10
11impl FloatTensorOps<Self> for Dispatch {
12 fn float_from_data(
13 data: burn_backend::TensorData,
14 device: &DispatchDevice,
15 ) -> FloatTensor<Self> {
16 creation_op!(Float, device, |device| B::float_from_data(data, device))
17 }
18
19 fn float_random(
20 shape: Shape,
21 distribution: burn_backend::Distribution,
22 device: &DispatchDevice,
23 dtype: FloatDType,
24 ) -> FloatTensor<Self> {
25 creation_op!(Float, device, |device| {
26 B::float_random(shape, distribution, device, dtype)
27 })
28 }
29
30 async fn float_into_data(tensor: FloatTensor<Self>) -> Result<TensorData, ExecutionError> {
31 unary_float!(tensor, float, |tensor| B::float_into_data(tensor).await)
32 }
33
34 fn float_device(tensor: &FloatTensor<Self>) -> DispatchDevice {
35 tensor.device()
36 }
37
38 fn float_to_device(tensor: FloatTensor<Self>, device: &DispatchDevice) -> FloatTensor<Self> {
39 float_to_device!(
40 Float,
41 float,
42 tensor,
43 device,
44 float_to_device,
45 |inner, device| {
46 let data =
47 burn_backend::read_sync(B1::float_into_data(inner)).expect("Should read data");
48 B2::float_from_data(data, device)
49 }
50 )
51 }
52
53 fn float_into_int(tensor: FloatTensor<Self>, dtype: burn_backend::IntDType) -> IntTensor<Self> {
54 unary_float!(tensor, float, |tensor| B::float_into_int(tensor, dtype) => Int)
55 }
56
57 fn float_empty(shape: Shape, device: &DispatchDevice, dtype: FloatDType) -> FloatTensor<Self> {
58 creation_op!(Float, device, |device| B::float_empty(shape, device, dtype))
59 }
60
61 fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
62 binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_add(lhs, rhs) => Float)
63 }
64
65 fn float_add_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
66 unary_float!(lhs, float, |lhs| B::float_add_scalar(lhs, rhs) => Float)
67 }
68
69 fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
70 binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_sub(lhs, rhs) => Float)
71 }
72
73 fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
74 unary_float!(lhs, float, |lhs| B::float_sub_scalar(lhs, rhs) => Float)
75 }
76
77 fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
78 binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_mul(lhs, rhs) => Float)
79 }
80
81 fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
82 unary_float!(lhs, float, |lhs| B::float_mul_scalar(lhs, rhs) => Float)
83 }
84
85 fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
86 binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_div(lhs, rhs) => Float)
87 }
88
89 fn float_div_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
90 unary_float!(lhs, float, |lhs| B::float_div_scalar(lhs, rhs) => Float)
91 }
92
93 fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
94 binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_remainder(lhs, rhs) => Float)
95 }
96
97 fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
98 unary_float!(lhs, float, |lhs| B::float_remainder_scalar(lhs, rhs) => Float)
99 }
100
101 fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
102 binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_matmul(lhs, rhs) => Float)
103 }
104
105 fn float_cross(
106 lhs: FloatTensor<Self>,
107 rhs: FloatTensor<Self>,
108 dim: usize,
109 ) -> FloatTensor<Self> {
110 binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_cross(lhs, rhs, dim) => Float)
111 }
112
113 fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
114 unary_float!(tensor, float, |tensor| B::float_recip(tensor) => Float)
115 }
116
117 fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
118 unary_float!(tensor, float, |tensor| B::float_swap_dims(tensor, dim1, dim2) => Float)
119 }
120
121 fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
122 unary_float!(tensor, float, |tensor| B::float_permute(tensor, axes) => Float)
123 }
124
125 fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
126 unary_float!(tensor, float, |tensor| B::float_flip(tensor, axes) => Float)
127 }
128
129 fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
130 unary_float!(tensor, float, |tensor| B::float_reshape(tensor, shape) => Float)
131 }
132
133 fn float_gather(
134 dim: usize,
135 tensor: FloatTensor<Self>,
136 indices: IntTensor<Self>,
137 ) -> FloatTensor<Self> {
138 binary_float!((tensor, float), (indices, int), |tensor, indices| B::float_gather(dim, tensor, indices) => Float)
139 }
140
141 fn float_scatter_add(
142 dim: usize,
143 tensor: FloatTensor<Self>,
144 indices: IntTensor<Self>,
145 value: FloatTensor<Self>,
146 ) -> FloatTensor<Self> {
147 multi_op!(
148 inputs[(tensor, float), (indices, int), (value, float)], => Float,
149 B::float_scatter_add(dim, tensor, indices, value)
150 )
151 }
152
153 fn float_select(
154 tensor: FloatTensor<Self>,
155 dim: usize,
156 indices: IntTensor<Self>,
157 ) -> FloatTensor<Self> {
158 binary_float!((tensor, float), (indices, int), |tensor, indices| B::float_select(tensor, dim, indices) => Float)
159 }
160
161 fn float_select_add(
162 tensor: FloatTensor<Self>,
163 dim: usize,
164 indices: IntTensor<Self>,
165 value: FloatTensor<Self>,
166 ) -> FloatTensor<Self> {
167 multi_op!(
168 inputs[(tensor, float), (indices, int), (value, float)], => Float,
169 B::float_select_add(tensor, dim, indices, value)
170 )
171 }
172
173 fn float_slice(tensor: FloatTensor<Self>, slices: &[Slice]) -> FloatTensor<Self> {
174 unary_float!(tensor, float, |tensor| B::float_slice(tensor, slices) => Float)
175 }
176
177 fn float_slice_assign(
178 tensor: FloatTensor<Self>,
179 slices: &[Slice],
180 value: FloatTensor<Self>,
181 ) -> FloatTensor<Self> {
182 binary_float!((tensor, float), (value, float), |tensor, value| B::float_slice_assign(tensor, slices, value) => Float)
183 }
184
185 fn float_mask_where(
186 tensor: FloatTensor<Self>,
187 mask: BoolTensor<Self>,
188 value: FloatTensor<Self>,
189 ) -> FloatTensor<Self> {
190 multi_op!(
191 inputs[(tensor, float), (mask, bool), (value, float)], => Float,
192 B::float_mask_where(tensor, mask, value)
193 )
194 }
195
196 fn float_mask_fill(
197 tensor: FloatTensor<Self>,
198 mask: BoolTensor<Self>,
199 value: Scalar,
200 ) -> FloatTensor<Self> {
201 binary_float!((tensor, float), (mask, bool), |tensor, mask| B::float_mask_fill(tensor, mask, value) => Float)
202 }
203
204 fn float_equal(
205 lhs: FloatTensor<Self>,
206 rhs: FloatTensor<Self>,
207 out_dtype: BoolDType,
208 ) -> BoolTensor<Self> {
209 binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_equal(lhs, rhs, out_dtype) => Bool)
210 }
211
212 fn float_equal_elem(
213 lhs: FloatTensor<Self>,
214 rhs: Scalar,
215 out_dtype: BoolDType,
216 ) -> BoolTensor<Self> {
217 unary_float!(lhs, float, |lhs| B::float_equal_elem(lhs, rhs, out_dtype) => Bool)
218 }
219
220 fn float_greater(
221 lhs: FloatTensor<Self>,
222 rhs: FloatTensor<Self>,
223 out_dtype: BoolDType,
224 ) -> BoolTensor<Self> {
225 binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_greater(lhs, rhs, out_dtype) => Bool)
226 }
227
228 fn float_greater_elem(
229 lhs: FloatTensor<Self>,
230 rhs: Scalar,
231 out_dtype: BoolDType,
232 ) -> BoolTensor<Self> {
233 unary_float!(lhs, float, |lhs| B::float_greater_elem(lhs, rhs, out_dtype) => Bool)
234 }
235
236 fn float_greater_equal(
237 lhs: FloatTensor<Self>,
238 rhs: FloatTensor<Self>,
239 out_dtype: BoolDType,
240 ) -> BoolTensor<Self> {
241 binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_greater_equal(lhs, rhs, out_dtype) => Bool)
242 }
243
244 fn float_greater_equal_elem(
245 lhs: FloatTensor<Self>,
246 rhs: Scalar,
247 out_dtype: BoolDType,
248 ) -> BoolTensor<Self> {
249 unary_float!(lhs, float, |lhs| B::float_greater_equal_elem(lhs, rhs, out_dtype) => Bool)
250 }
251
252 fn float_lower(
253 lhs: FloatTensor<Self>,
254 rhs: FloatTensor<Self>,
255 out_dtype: BoolDType,
256 ) -> BoolTensor<Self> {
257 binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_lower(lhs, rhs, out_dtype) => Bool)
258 }
259
260 fn float_lower_elem(
261 lhs: FloatTensor<Self>,
262 rhs: Scalar,
263 out_dtype: BoolDType,
264 ) -> BoolTensor<Self> {
265 unary_float!(lhs, float, |lhs| B::float_lower_elem(lhs, rhs, out_dtype) => Bool)
266 }
267
268 fn float_lower_equal(
269 lhs: FloatTensor<Self>,
270 rhs: FloatTensor<Self>,
271 out_dtype: BoolDType,
272 ) -> BoolTensor<Self> {
273 binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_lower_equal(lhs, rhs, out_dtype) => Bool)
274 }
275
276 fn float_lower_equal_elem(
277 lhs: FloatTensor<Self>,
278 rhs: Scalar,
279 out_dtype: BoolDType,
280 ) -> BoolTensor<Self> {
281 unary_float!(lhs, float, |lhs| B::float_lower_equal_elem(lhs, rhs, out_dtype) => Bool)
282 }
283
284 fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
285 unary_float!(tensor, float, |tensor| B::float_sum(tensor) => Float)
286 }
287
288 fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
289 unary_float!(tensor, float, |tensor| B::float_sum_dim(tensor, dim) => Float)
290 }
291
292 fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
293 unary_float!(tensor, float, |tensor| B::float_mean_dim(tensor, dim) => Float)
294 }
295
296 fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
297 unary_float!(tensor, float, |tensor| B::float_cumsum(tensor, dim) => Float)
298 }
299
300 fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
301 unary_float!(tensor, float, |tensor| B::float_cumprod(tensor, dim) => Float)
302 }
303
304 fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
305 unary_float!(tensor, float, |tensor| B::float_cummin(tensor, dim) => Float)
306 }
307
308 fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
309 unary_float!(tensor, float, |tensor| B::float_cummax(tensor, dim) => Float)
310 }
311
312 fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
313 unary_float!(tensor, float, |tensor| B::float_cast(tensor, dtype) => Float)
314 }
315
316 fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
317 unary_float!(tensor, float, |tensor| B::float_exp(tensor) => Float)
318 }
319
320 fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
321 unary_float!(tensor, float, |tensor| B::float_log(tensor) => Float)
322 }
323
324 fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
325 unary_float!(tensor, float, |tensor| B::float_log1p(tensor) => Float)
326 }
327
328 fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
329 binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_powf(lhs, rhs) => Float)
330 }
331
332 fn float_powf_scalar_impl(tensor: FloatTensor<Self>, value: Scalar) -> FloatTensor<Self> {
333 unary_float!(tensor, float, |tensor| B::float_powf_scalar_impl(tensor, value) => Float)
334 }
335
336 fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
337 unary_float!(tensor, float, |tensor| B::float_sqrt(tensor) => Float)
338 }
339
340 fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
341 unary_float!(tensor, float, |tensor| B::float_abs(tensor) => Float)
342 }
343
344 fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
345 unary_float!(tensor, float, |tensor| B::float_cos(tensor) => Float)
346 }
347
348 fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
349 unary_float!(tensor, float, |tensor| B::float_sin(tensor) => Float)
350 }
351
352 fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
353 unary_float!(tensor, float, |tensor| B::float_tan(tensor) => Float)
354 }
355
356 fn float_cosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
357 unary_float!(tensor, float, |tensor| B::float_cosh(tensor) => Float)
358 }
359
360 fn float_sinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
361 unary_float!(tensor, float, |tensor| B::float_sinh(tensor) => Float)
362 }
363
364 fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
365 unary_float!(tensor, float, |tensor| B::float_tanh(tensor) => Float)
366 }
367
368 fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
369 unary_float!(tensor, float, |tensor| B::float_acos(tensor) => Float)
370 }
371
372 fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
373 unary_float!(tensor, float, |tensor| B::float_acosh(tensor) => Float)
374 }
375
376 fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
377 unary_float!(tensor, float, |tensor| B::float_asin(tensor) => Float)
378 }
379
380 fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
381 unary_float!(tensor, float, |tensor| B::float_asinh(tensor) => Float)
382 }
383
384 fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
385 unary_float!(tensor, float, |tensor| B::float_atan(tensor) => Float)
386 }
387
388 fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
389 unary_float!(tensor, float, |tensor| B::float_atanh(tensor) => Float)
390 }
391
392 fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
393 binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_atan2(lhs, rhs) => Float)
394 }
395
396 fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
397 unary_float!(tensor, float, |tensor| B::float_round(tensor) => Float)
398 }
399
400 fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
401 unary_float!(tensor, float, |tensor| B::float_floor(tensor) => Float)
402 }
403
404 fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
405 unary_float!(tensor, float, |tensor| B::float_ceil(tensor) => Float)
406 }
407
408 fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
409 unary_float!(tensor, float, |tensor| B::float_trunc(tensor) => Float)
410 }
411
412 fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
413 unary_float!(tensor, float, |tensor| B::float_erf(tensor) => Float)
414 }
415
416 fn float_argmax(tensor: FloatTensor<Self>, dim: usize, out_dtype: IntDType) -> IntTensor<Self> {
417 unary_float!(tensor, float, |tensor| B::float_argmax(tensor, dim, out_dtype) => Int)
418 }
419
420 fn float_argmin(tensor: FloatTensor<Self>, dim: usize, out_dtype: IntDType) -> IntTensor<Self> {
421 unary_float!(tensor, float, |tensor| B::float_argmin(tensor, dim, out_dtype) => Int)
422 }
423
424 fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
425 unary_float!(tensor, float, |tensor| B::float_expand(tensor, shape) => Float)
426 }
427
428 fn float_unfold(
429 tensor: FloatTensor<Self>,
430 dim: usize,
431 size: usize,
432 step: usize,
433 ) -> FloatTensor<Self> {
434 unary_float!(tensor, float, |tensor| {
435 B::float_unfold(tensor, dim, size, step)
436 } => Float)
437 }
438
439 fn float_detach(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
440 unary_float!(tensor, float, |tensor| B::float_detach(tensor) => Float)
441 }
442
443 fn float_set_require_grad(tensor: FloatTensor<Self>, require_grad: bool) -> FloatTensor<Self> {
444 unary_float!(tensor, float, |tensor| B::float_set_require_grad(tensor, require_grad) => Float)
445 }
446
447 fn float_is_require_grad(tensor: &FloatTensor<Self>) -> bool {
448 unary_float!(ref tensor, float, |tensor| B::float_is_require_grad(tensor))
449 }
450
451 fn float_zeros(shape: Shape, device: &DispatchDevice, dtype: FloatDType) -> FloatTensor<Self> {
453 creation_op!(Float, device, |device| B::float_zeros(shape, device, dtype))
454 }
455
456 fn float_ones(shape: Shape, device: &DispatchDevice, dtype: FloatDType) -> FloatTensor<Self> {
457 creation_op!(Float, device, |device| B::float_ones(shape, device, dtype))
458 }
459
460 fn float_full(
461 shape: Shape,
462 fill_value: Scalar,
463 device: &DispatchDevice,
464 dtype: FloatDType,
465 ) -> FloatTensor<Self> {
466 creation_op!(Float, device, |device| B::float_full(
467 shape, fill_value, device, dtype
468 ))
469 }
470
471 fn float_repeat_dim(tensor: FloatTensor<Self>, dim: usize, times: usize) -> FloatTensor<Self> {
472 unary_float!(tensor, float, |tensor| B::float_repeat_dim(tensor, dim, times) => Float)
473 }
474
475 fn float_clamp_min(tensor: FloatTensor<Self>, min: Scalar) -> FloatTensor<Self> {
476 unary_float!(tensor, float, |tensor| B::float_clamp_min(tensor, min) => Float)
477 }
478
479 fn float_clamp_max(tensor: FloatTensor<Self>, max: Scalar) -> FloatTensor<Self> {
480 unary_float!(tensor, float, |tensor| B::float_clamp_max(tensor, max) => Float)
481 }
482
483 fn float_clamp(tensor: FloatTensor<Self>, min: Scalar, max: Scalar) -> FloatTensor<Self> {
484 unary_float!(tensor, float, |tensor| B::float_clamp(tensor, min, max) => Float)
485 }
486
487 fn float_neg(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
488 unary_float!(tensor, float, |tensor| B::float_neg(tensor) => Float)
489 }
490
491 fn float_transpose(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
492 unary_float!(tensor, float, |tensor| B::float_transpose(tensor) => Float)
493 }
494
495 fn float_not_equal(
496 lhs: FloatTensor<Self>,
497 rhs: FloatTensor<Self>,
498 out_dtype: BoolDType,
499 ) -> BoolTensor<Self> {
500 binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_not_equal(lhs, rhs, out_dtype) => Bool)
501 }
502
503 fn float_not_equal_elem(
504 lhs: FloatTensor<Self>,
505 rhs: Scalar,
506 out_dtype: BoolDType,
507 ) -> BoolTensor<Self> {
508 unary_float!(lhs, float, |lhs| B::float_not_equal_elem(lhs, rhs, out_dtype) => Bool)
509 }
510
511 fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
512 unary_float!(tensor, float, |tensor| B::float_prod(tensor) => Float)
513 }
514
515 fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
516 unary_float!(tensor, float, |tensor| B::float_prod_dim(tensor, dim) => Float)
517 }
518
519 fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
520 unary_float!(tensor, float, |tensor| B::float_mean(tensor) => Float)
521 }
522
523 fn float_powi(lhs: FloatTensor<Self>, rhs: IntTensor<Self>) -> FloatTensor<Self> {
524 binary_float!((lhs, float), (rhs, int), |lhs, rhs| B::float_powi(lhs, rhs) => Float)
525 }
526
527 fn float_powi_scalar_impl(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
528 unary_float!(lhs, float, |lhs| B::float_powi_scalar_impl(lhs, rhs) => Float)
529 }
530
531 fn float_powf_scalar(tensor: FloatTensor<Self>, value: Scalar) -> FloatTensor<Self> {
532 unary_float!(tensor, float, |tensor| B::float_powf_scalar(tensor, value) => Float)
533 }
534
535 fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {
536 vec_op!(tensors, float, |tensors| B::float_cat(tensors, dim) => Float)
537 }
538
539 fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
540 unary_float!(tensor, float, |tensor| B::float_max(tensor) => Float)
541 }
542
543 fn float_max_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
544 unary_float!(tensor, float, |tensor| B::float_max_dim(tensor, dim) => Float)
545 }
546
547 fn float_max_dim_with_indices(
548 tensor: FloatTensor<Self>,
549 dim: usize,
550 indices_dtype: IntDType,
551 ) -> (FloatTensor<Self>, IntTensor<Self>) {
552 multi_op!(
553 inputs[(tensor, float)],
554 outputs[(out, Float), (indices, Int)],
555 B::float_max_dim_with_indices(tensor, dim, indices_dtype)
556 )
557 }
558
559 fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
560 unary_float!(tensor, float, |tensor| B::float_min(tensor) => Float)
561 }
562
563 fn float_min_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
564 unary_float!(tensor, float, |tensor| B::float_min_dim(tensor, dim) => Float)
565 }
566
567 fn float_min_dim_with_indices(
568 tensor: FloatTensor<Self>,
569 dim: usize,
570 indices_dtype: IntDType,
571 ) -> (FloatTensor<Self>, IntTensor<Self>) {
572 multi_op!(
573 inputs[(tensor, float)],
574 outputs[(out, Float), (indices, Int)],
575 B::float_min_dim_with_indices(tensor, dim, indices_dtype)
576 )
577 }
578
579 fn float_max_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
580 unary_float!(tensor, float, |tensor| B::float_max_abs(tensor) => Float)
581 }
582
583 fn float_max_abs_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
584 unary_float!(tensor, float, |tensor| B::float_max_abs_dim(tensor, dim) => Float)
585 }
586
587 fn float_any(tensor: FloatTensor<Self>, out_dtype: BoolDType) -> BoolTensor<Self> {
588 unary_float!(tensor, float, |tensor| B::float_any(tensor, out_dtype) => Bool)
589 }
590
591 fn float_any_dim(
592 tensor: FloatTensor<Self>,
593 dim: usize,
594 out_dtype: BoolDType,
595 ) -> BoolTensor<Self> {
596 unary_float!(tensor, float, |tensor| B::float_any_dim(tensor, dim, out_dtype) => Bool)
597 }
598
599 fn float_all(tensor: FloatTensor<Self>, out_dtype: BoolDType) -> BoolTensor<Self> {
600 unary_float!(tensor, float, |tensor| B::float_all(tensor, out_dtype) => Bool)
601 }
602
603 fn float_all_dim(
604 tensor: FloatTensor<Self>,
605 dim: usize,
606 out_dtype: BoolDType,
607 ) -> BoolTensor<Self> {
608 unary_float!(tensor, float, |tensor| B::float_all_dim(tensor, dim, out_dtype) => Bool)
609 }
610
611 fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
612 unary_float!(tensor, float, |tensor| B::float_sign(tensor) => Float)
613 }
614
615 fn float_sort(tensor: FloatTensor<Self>, dim: usize, descending: bool) -> FloatTensor<Self> {
616 unary_float!(tensor, float, |tensor| B::float_sort(tensor, dim, descending) => Float)
617 }
618
619 fn float_sort_with_indices(
620 tensor: FloatTensor<Self>,
621 dim: usize,
622 descending: bool,
623 indices_dtype: IntDType,
624 ) -> (FloatTensor<Self>, IntTensor<Self>) {
625 multi_op!(
626 inputs[(tensor, float)],
627 outputs[(out, Float), (indices, Int)],
628 B::float_sort_with_indices(tensor, dim, descending, indices_dtype)
629 )
630 }
631
632 fn float_argsort(
633 tensor: FloatTensor<Self>,
634 dim: usize,
635 descending: bool,
636 out_dtype: IntDType,
637 ) -> IntTensor<Self> {
638 unary_float!(tensor, float, |tensor| B::float_argsort(tensor, dim, descending, out_dtype) => Int)
639 }
640
641 fn float_grid_sample_2d(
642 tensor: FloatTensor<Self>,
643 grid: FloatTensor<Self>,
644 options: burn_backend::ops::GridSampleOptions,
645 ) -> FloatTensor<Self> {
646 binary_float!((tensor, float), (grid, float), |tensor, grid| B::float_grid_sample_2d(tensor, grid, options) => Float)
647 }
648
649 fn float_is_nan(tensor: FloatTensor<Self>, out_dtype: BoolDType) -> BoolTensor<Self> {
650 unary_float!(tensor, float, |tensor| B::float_is_nan(tensor, out_dtype) => Bool)
651 }
652
653 fn float_is_inf(tensor: FloatTensor<Self>, out_dtype: BoolDType) -> BoolTensor<Self> {
654 unary_float!(tensor, float, |tensor| B::float_is_inf(tensor, out_dtype) => Bool)
655 }
656}