1use burn_backend::{
2 ExecutionError, Scalar, TensorData,
3 ops::FloatTensorOps,
4 tensor::{BoolTensor, FloatTensor, IntTensor},
5};
6use burn_std::{FloatDType, Shape, Slice};
7
8use crate::backends::*;
9use crate::{Dispatch, DispatchDevice};
10
// `FloatTensorOps` for the `Dispatch` backend.
//
// Every method forwards to a concrete backend through the dispatch macros
// imported from `crate::backends` (`creation_op!`, `unary_float!`,
// `binary_float!`, `multi_op!`, `vec_op!`, `float_to_device!`). Their
// definitions are not visible in this file. Judging from the call sites below,
// they appear to work as follows:
//   - They bind the concrete backend type as `B` (or `B1`/`B2` in
//     `float_to_device!`).
//   - They unwrap each input named in a `(ident, kind)` pair. The lowercase
//     kind (`float` / `int` / `bool`) selects which tensor primitive to
//     extract.
//   - They evaluate the given expression on the unwrapped values.
//   - They re-wrap the result as the tensor kind written after `=>`
//     (`Float` / `Int` / `Bool`). That kind matches each method's return type.
//     Without a `=>`, the expression's value is returned as-is (see
//     `float_into_data` and `float_is_require_grad`).
// NOTE(review): the behavior above is inferred from usage only. Confirm it
// against the macro definitions, in particular what happens when the inputs
// of a binary or multi-input op live on different backends.
impl FloatTensorOps<Self> for Dispatch {
    // ---------------------------------------------------------------------
    // Creation, reading data back, and device handling.
    // `creation_op!` receives the `DispatchDevice` and hands the closure a
    // backend-specific `device`. Presumably that device selects `B`; confirm
    // in the macro.
    // ---------------------------------------------------------------------

    fn float_from_data(
        data: burn_backend::TensorData,
        device: &DispatchDevice,
    ) -> FloatTensor<Self> {
        creation_op!(Float, device, |device| B::float_from_data(data, device))
    }

    fn float_random(
        shape: Shape,
        distribution: burn_backend::Distribution,
        device: &DispatchDevice,
    ) -> FloatTensor<Self> {
        creation_op!(Float, device, |device| {
            B::float_random(shape, distribution, device)
        })
    }

    // Awaits the inner backend's read. There is no `=> Kind`, so the
    // `Result<TensorData, ExecutionError>` is passed through unchanged.
    async fn float_into_data(tensor: FloatTensor<Self>) -> Result<TensorData, ExecutionError> {
        unary_float!(tensor, float, |tensor| B::float_into_data(tensor).await)
    }

    // Uses the dispatch tensor's own `device()` method instead of a dispatch macro.
    fn float_device(tensor: &FloatTensor<Self>) -> DispatchDevice {
        tensor.device()
    }

    // The closure moves a tensor from backend `B1` to backend `B2`.
    // - It reads the data out of `B1` synchronously via
    //   `burn_backend::read_sync`.
    // - It then rebuilds the tensor with `B2::float_from_data` on the target
    //   device.
    // - If the read fails it panics with "Should read data", because this
    //   method returns a plain `FloatTensor` and cannot surface an error.
    // NOTE(review): the `float_to_device` identifier passed to the macro is
    // presumably the op used when source and target share a backend. Confirm
    // this in `float_to_device!`.
    fn float_to_device(tensor: FloatTensor<Self>, device: &DispatchDevice) -> FloatTensor<Self> {
        float_to_device!(
            Float,
            float,
            tensor,
            device,
            float_to_device,
            |inner, device| {
                let data =
                    burn_backend::read_sync(B1::float_into_data(inner)).expect("Should read data");
                B2::float_from_data(data, device)
            }
        )
    }

    // Float -> Int conversion; the result is re-wrapped as an `Int` tensor.
    fn float_into_int(tensor: FloatTensor<Self>) -> IntTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_into_int(tensor) => Int)
    }

    fn float_empty(shape: Shape, device: &DispatchDevice, dtype: FloatDType) -> FloatTensor<Self> {
        creation_op!(Float, device, |device| B::float_empty(shape, device, dtype))
    }

    // ---------------------------------------------------------------------
    // Arithmetic. Tensor-tensor ops use `binary_float!`; tensor-scalar ops use
    // `unary_float!` and move the `Scalar` into the forwarded call.
    // ---------------------------------------------------------------------

    fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_add(lhs, rhs) => Float)
    }

    fn float_add_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
        unary_float!(lhs, float, |lhs| B::float_add_scalar(lhs, rhs) => Float)
    }

    fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_sub(lhs, rhs) => Float)
    }

    fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
        unary_float!(lhs, float, |lhs| B::float_sub_scalar(lhs, rhs) => Float)
    }

    fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_mul(lhs, rhs) => Float)
    }

    fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
        unary_float!(lhs, float, |lhs| B::float_mul_scalar(lhs, rhs) => Float)
    }

    fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_div(lhs, rhs) => Float)
    }

    fn float_div_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
        unary_float!(lhs, float, |lhs| B::float_div_scalar(lhs, rhs) => Float)
    }

    fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_remainder(lhs, rhs) => Float)
    }

    fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
        unary_float!(lhs, float, |lhs| B::float_remainder_scalar(lhs, rhs) => Float)
    }

    fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_matmul(lhs, rhs) => Float)
    }

    fn float_cross(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        dim: usize,
    ) -> FloatTensor<Self> {
        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_cross(lhs, rhs, dim) => Float)
    }

    fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_recip(tensor) => Float)
    }

    // ---------------------------------------------------------------------
    // Layout / shape manipulation.
    // ---------------------------------------------------------------------

    fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_swap_dims(tensor, dim1, dim2) => Float)
    }

    fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_permute(tensor, axes) => Float)
    }

    fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_flip(tensor, axes) => Float)
    }

    fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_reshape(tensor, shape) => Float)
    }

    // ---------------------------------------------------------------------
    // Indexing: gather / scatter / select / slice / mask.
    // Index tensors are unwrapped as `int` and masks as `bool`; the data and
    // value tensors are unwrapped as `float`. Ops with three tensor inputs
    // use `multi_op!`.
    // ---------------------------------------------------------------------

    fn float_gather(
        dim: usize,
        tensor: FloatTensor<Self>,
        indices: IntTensor<Self>,
    ) -> FloatTensor<Self> {
        binary_float!((tensor, float), (indices, int), |tensor, indices| B::float_gather(dim, tensor, indices) => Float)
    }

    fn float_scatter_add(
        dim: usize,
        tensor: FloatTensor<Self>,
        indices: IntTensor<Self>,
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        multi_op!(
            inputs[(tensor, float), (indices, int), (value, float)], => Float,
            B::float_scatter_add(dim, tensor, indices, value)
        )
    }

    fn float_select(
        tensor: FloatTensor<Self>,
        dim: usize,
        indices: IntTensor<Self>,
    ) -> FloatTensor<Self> {
        binary_float!((tensor, float), (indices, int), |tensor, indices| B::float_select(tensor, dim, indices) => Float)
    }

    fn float_select_add(
        tensor: FloatTensor<Self>,
        dim: usize,
        indices: IntTensor<Self>,
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        multi_op!(
            inputs[(tensor, float), (indices, int), (value, float)], => Float,
            B::float_select_add(tensor, dim, indices, value)
        )
    }

    fn float_slice(tensor: FloatTensor<Self>, slices: &[Slice]) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_slice(tensor, slices) => Float)
    }

    fn float_slice_assign(
        tensor: FloatTensor<Self>,
        slices: &[Slice],
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        binary_float!((tensor, float), (value, float), |tensor, value| B::float_slice_assign(tensor, slices, value) => Float)
    }

    fn float_mask_where(
        tensor: FloatTensor<Self>,
        mask: BoolTensor<Self>,
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        multi_op!(
            inputs[(tensor, float), (mask, bool), (value, float)], => Float,
            B::float_mask_where(tensor, mask, value)
        )
    }

    fn float_mask_fill(
        tensor: FloatTensor<Self>,
        mask: BoolTensor<Self>,
        value: Scalar,
    ) -> FloatTensor<Self> {
        binary_float!((tensor, float), (mask, bool), |tensor, mask| B::float_mask_fill(tensor, mask, value) => Float)
    }

    // ---------------------------------------------------------------------
    // Comparisons. Results are re-wrapped as `Bool` tensors.
    // The `_elem` variants compare against a `Scalar`.
    // ---------------------------------------------------------------------

    fn float_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_equal(lhs, rhs) => Bool)
    }

    fn float_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
        unary_float!(lhs, float, |lhs| B::float_equal_elem(lhs, rhs) => Bool)
    }

    fn float_greater(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_greater(lhs, rhs) => Bool)
    }

    fn float_greater_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
        unary_float!(lhs, float, |lhs| B::float_greater_elem(lhs, rhs) => Bool)
    }

    fn float_greater_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_greater_equal(lhs, rhs) => Bool)
    }

    fn float_greater_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
        unary_float!(lhs, float, |lhs| B::float_greater_equal_elem(lhs, rhs) => Bool)
    }

    fn float_lower(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_lower(lhs, rhs) => Bool)
    }

    fn float_lower_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
        unary_float!(lhs, float, |lhs| B::float_lower_elem(lhs, rhs) => Bool)
    }

    fn float_lower_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_lower_equal(lhs, rhs) => Bool)
    }

    fn float_lower_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
        unary_float!(lhs, float, |lhs| B::float_lower_equal_elem(lhs, rhs) => Bool)
    }

    // ---------------------------------------------------------------------
    // Reductions and cumulative ops (dimension-wise variants take `dim`).
    // ---------------------------------------------------------------------

    fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_sum(tensor) => Float)
    }

    fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_sum_dim(tensor, dim) => Float)
    }

    fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_mean_dim(tensor, dim) => Float)
    }

    fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_cumsum(tensor, dim) => Float)
    }

    fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_cumprod(tensor, dim) => Float)
    }

    fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_cummin(tensor, dim) => Float)
    }

    fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_cummax(tensor, dim) => Float)
    }

    // Float-to-float dtype cast; the result stays a `Float` tensor.
    fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_cast(tensor, dtype) => Float)
    }

    // ---------------------------------------------------------------------
    // Element-wise math functions.
    // ---------------------------------------------------------------------

    fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_exp(tensor) => Float)
    }

    fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_log(tensor) => Float)
    }

    fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_log1p(tensor) => Float)
    }

    fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_powf(lhs, rhs) => Float)
    }

    fn float_powf_scalar_impl(tensor: FloatTensor<Self>, value: Scalar) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_powf_scalar_impl(tensor, value) => Float)
    }

    fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_sqrt(tensor) => Float)
    }

    fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_abs(tensor) => Float)
    }

    fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_cos(tensor) => Float)
    }

    fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_sin(tensor) => Float)
    }

    fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_tan(tensor) => Float)
    }

    fn float_cosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_cosh(tensor) => Float)
    }

    fn float_sinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_sinh(tensor) => Float)
    }

    fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_tanh(tensor) => Float)
    }

    fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_acos(tensor) => Float)
    }

    fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_acosh(tensor) => Float)
    }

    fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_asin(tensor) => Float)
    }

    fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_asinh(tensor) => Float)
    }

    fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_atan(tensor) => Float)
    }

    fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_atanh(tensor) => Float)
    }

    fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_atan2(lhs, rhs) => Float)
    }

    fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_round(tensor) => Float)
    }

    fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_floor(tensor) => Float)
    }

    fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_ceil(tensor) => Float)
    }

    fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_trunc(tensor) => Float)
    }

    fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_erf(tensor) => Float)
    }

    // Index-producing reductions: re-wrapped as `Int` tensors.
    fn float_argmax(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_argmax(tensor, dim) => Int)
    }

    fn float_argmin(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_argmin(tensor, dim) => Int)
    }

    fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_expand(tensor, shape) => Float)
    }

    fn float_unfold(
        tensor: FloatTensor<Self>,
        dim: usize,
        size: usize,
        step: usize,
    ) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| {
            B::float_unfold(tensor, dim, size, step)
        } => Float)
    }

    // ---------------------------------------------------------------------
    // Gradient-tracking flags, forwarded unchanged to the inner backend.
    // ---------------------------------------------------------------------

    fn float_detach(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_detach(tensor) => Float)
    }

    fn float_set_require_grad(tensor: FloatTensor<Self>, require_grad: bool) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_set_require_grad(tensor, require_grad) => Float)
    }

    // The input is only borrowed here (`&FloatTensor`), hence the `ref` form of
    // the macro. The resulting `bool` is returned as-is (no `=> Kind`).
    fn float_is_require_grad(tensor: &FloatTensor<Self>) -> bool {
        unary_float!(ref tensor, float, |tensor| B::float_is_require_grad(tensor))
    }

    // ---------------------------------------------------------------------
    // Constant-filled creation.
    // ---------------------------------------------------------------------

    fn float_zeros(shape: Shape, device: &DispatchDevice, dtype: FloatDType) -> FloatTensor<Self> {
        creation_op!(Float, device, |device| B::float_zeros(shape, device, dtype))
    }

    fn float_ones(shape: Shape, device: &DispatchDevice, dtype: FloatDType) -> FloatTensor<Self> {
        creation_op!(Float, device, |device| B::float_ones(shape, device, dtype))
    }

    fn float_full(
        shape: Shape,
        fill_value: Scalar,
        device: &DispatchDevice,
        dtype: FloatDType,
    ) -> FloatTensor<Self> {
        creation_op!(Float, device, |device| B::float_full(
            shape, fill_value, device, dtype
        ))
    }

    // ---------------------------------------------------------------------
    // Ops that the trait may provide defaults for. They are forwarded
    // explicitly so the inner backend's own implementation is used.
    // ---------------------------------------------------------------------

    fn float_repeat_dim(tensor: FloatTensor<Self>, dim: usize, times: usize) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_repeat_dim(tensor, dim, times) => Float)
    }

    fn float_clamp_min(tensor: FloatTensor<Self>, min: Scalar) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_clamp_min(tensor, min) => Float)
    }

    fn float_clamp_max(tensor: FloatTensor<Self>, max: Scalar) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_clamp_max(tensor, max) => Float)
    }

    fn float_clamp(tensor: FloatTensor<Self>, min: Scalar, max: Scalar) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_clamp(tensor, min, max) => Float)
    }

    fn float_neg(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_neg(tensor) => Float)
    }

    fn float_transpose(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_transpose(tensor) => Float)
    }

    fn float_not_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_not_equal(lhs, rhs) => Bool)
    }

    fn float_not_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
        unary_float!(lhs, float, |lhs| B::float_not_equal_elem(lhs, rhs) => Bool)
    }

    fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_prod(tensor) => Float)
    }

    fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_prod_dim(tensor, dim) => Float)
    }

    fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_mean(tensor) => Float)
    }

    // Float base with an `Int` tensor exponent: `rhs` is unwrapped as `int`.
    fn float_powi(lhs: FloatTensor<Self>, rhs: IntTensor<Self>) -> FloatTensor<Self> {
        binary_float!((lhs, float), (rhs, int), |lhs, rhs| B::float_powi(lhs, rhs) => Float)
    }

    fn float_powi_scalar_impl(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
        unary_float!(lhs, float, |lhs| B::float_powi_scalar_impl(lhs, rhs) => Float)
    }

    // This impl overrides both `float_powf_scalar` and `float_powf_scalar_impl`
    // (the latter above). Each forwards to the backend method of the same name,
    // so whichever entry point is called reaches the inner backend's version.
    fn float_powf_scalar(tensor: FloatTensor<Self>, value: Scalar) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_powf_scalar(tensor, value) => Float)
    }

    // `vec_op!` unwraps every tensor in the `Vec`.
    // NOTE(review): presumably all tensors must share one backend; confirm in
    // the macro.
    fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {
        vec_op!(tensors, float, |tensors| B::float_cat(tensors, dim) => Float)
    }

    // ---------------------------------------------------------------------
    // Max / min. The `_with_indices` variants return two outputs: the values
    // as `Float` (`out`) and the positions as `Int` (`indices`).
    // ---------------------------------------------------------------------

    fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_max(tensor) => Float)
    }

    fn float_max_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_max_dim(tensor, dim) => Float)
    }

    fn float_max_dim_with_indices(
        tensor: FloatTensor<Self>,
        dim: usize,
    ) -> (FloatTensor<Self>, IntTensor<Self>) {
        multi_op!(
            inputs[(tensor, float)],
            outputs[(out, Float), (indices, Int)],
            B::float_max_dim_with_indices(tensor, dim)
        )
    }

    fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_min(tensor) => Float)
    }

    fn float_min_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_min_dim(tensor, dim) => Float)
    }

    fn float_min_dim_with_indices(
        tensor: FloatTensor<Self>,
        dim: usize,
    ) -> (FloatTensor<Self>, IntTensor<Self>) {
        multi_op!(
            inputs[(tensor, float)],
            outputs[(out, Float), (indices, Int)],
            B::float_min_dim_with_indices(tensor, dim)
        )
    }

    fn float_max_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_max_abs(tensor) => Float)
    }

    fn float_max_abs_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_max_abs_dim(tensor, dim) => Float)
    }

    // Boolean reductions: re-wrapped as `Bool` tensors.
    fn float_any(tensor: FloatTensor<Self>) -> BoolTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_any(tensor) => Bool)
    }

    fn float_any_dim(tensor: FloatTensor<Self>, dim: usize) -> BoolTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_any_dim(tensor, dim) => Bool)
    }

    fn float_all(tensor: FloatTensor<Self>) -> BoolTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_all(tensor) => Bool)
    }

    fn float_all_dim(tensor: FloatTensor<Self>, dim: usize) -> BoolTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_all_dim(tensor, dim) => Bool)
    }

    fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_sign(tensor) => Float)
    }

    // ---------------------------------------------------------------------
    // Sorting. `float_sort_with_indices` returns (values: Float, indices: Int);
    // `float_argsort` returns only the `Int` indices.
    // ---------------------------------------------------------------------

    fn float_sort(tensor: FloatTensor<Self>, dim: usize, descending: bool) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_sort(tensor, dim, descending) => Float)
    }

    fn float_sort_with_indices(
        tensor: FloatTensor<Self>,
        dim: usize,
        descending: bool,
    ) -> (FloatTensor<Self>, IntTensor<Self>) {
        multi_op!(
            inputs[(tensor, float)],
            outputs[(out, Float), (indices, Int)],
            B::float_sort_with_indices(tensor, dim, descending)
        )
    }

    fn float_argsort(tensor: FloatTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_argsort(tensor, dim, descending) => Int)
    }

    // Both `tensor` and `grid` are unwrapped as `float`; `options` is moved
    // into the forwarded call unchanged.
    fn float_grid_sample_2d(
        tensor: FloatTensor<Self>,
        grid: FloatTensor<Self>,
        options: burn_backend::ops::GridSampleOptions,
    ) -> FloatTensor<Self> {
        binary_float!((tensor, float), (grid, float), |tensor, grid| B::float_grid_sample_2d(tensor, grid, options) => Float)
    }

    // Element-wise NaN / infinity tests: re-wrapped as `Bool` tensors.
    fn float_is_nan(tensor: FloatTensor<Self>) -> BoolTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_is_nan(tensor) => Bool)
    }

    fn float_is_inf(tensor: FloatTensor<Self>) -> BoolTensor<Self> {
        unary_float!(tensor, float, |tensor| B::float_is_inf(tensor) => Bool)
    }
}