Skip to main content

burn_router/ops/
tensor.rs

1use alloc::vec::Vec;
2use burn_backend::backend::ExecutionError;
3use burn_backend::{Scalar, tensor::FloatElem};
4use burn_std::{BoolDType, IntDType};
5
6use crate::{BackendRouter, RunnerChannel, RunnerClient, get_client};
7use burn_backend::tensor::{BoolTensor, Device, FloatTensor, IndexingUpdateOp, IntTensor};
8use burn_backend::{Distribution, FloatDType, Shape, Slice, TensorData, ops::FloatTensorOps};
9use burn_ir::{
10    BaseOperationIr, BinaryOpIr, CastOpIr, CatOpIr, ClampOpIr, CreationOpIr, CrossOpIr, DimOpIr,
11    FlipOpIr, FloatOperationIr, FullOpIr, GatherNdOpIr, GatherOpIr, InitOperationIr, MaskFillOpIr,
12    MaskWhereOpIr, MatmulOpIr, NumericOperationIr, OperationIr, OperationOutput, PermuteOpIr,
13    RandomOpIr, ReduceDimOpIr, ReduceDimWithIndicesOpIr, ReduceOpIr, RepeatDimOpIr, ScalarOpIr,
14    ScatterNdOpIr, ScatterOpIr, SelectAssignOpIr, SelectOpIr, ShapeOpIr, SliceAssignOpIr,
15    SliceOpIr, SwapDimsOpIr, UnaryOpIr, UnfoldOpIr,
16};
17
18impl<R: RunnerChannel> FloatTensorOps<Self> for BackendRouter<R> {
19    fn float_from_data(data: TensorData, device: &Device<Self>) -> FloatTensor<Self> {
20        let client = get_client::<R>(device);
21        let out = client.register_tensor_data(data);
22        let desc = InitOperationIr {
23            out: out.to_ir_out(),
24        };
25
26        // Call register op when output is already initialized
27        client.register_op(OperationIr::Init(desc));
28
29        out
30    }
31
32    fn float_random(
33        shape: Shape,
34        distribution: Distribution,
35        device: &Device<Self>,
36        dtype: FloatDType,
37    ) -> FloatTensor<Self> {
38        let client = get_client::<R>(device);
39        let dtype = dtype.into();
40        let desc = RandomOpIr::create(shape, dtype, distribution, || client.create_empty_handle());
41
42        client
43            .register(OperationIr::Float(dtype, FloatOperationIr::Random(desc)))
44            .output()
45    }
46
47    fn float_zeros(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
48        let client = get_client::<R>(device);
49        let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
50
51        client
52            .register(OperationIr::BaseFloat(BaseOperationIr::Zeros(desc)))
53            .output()
54    }
55
56    fn float_ones(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
57        let client = get_client::<R>(device);
58        let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
59
60        client
61            .register(OperationIr::BaseFloat(BaseOperationIr::Ones(desc)))
62            .output()
63    }
64
65    fn float_full(
66        shape: Shape,
67        fill_value: Scalar,
68        device: &Device<Self>,
69        dtype: FloatDType,
70    ) -> FloatTensor<Self> {
71        let client = get_client::<R>(device);
72        let dtype = dtype.into();
73        let value = fill_value.into();
74        let desc = FullOpIr::create(shape, dtype, value, || client.create_empty_handle());
75
76        client
77            .register(OperationIr::NumericFloat(
78                desc.out.dtype,
79                NumericOperationIr::Full(desc),
80            ))
81            .output()
82    }
83
84    async fn float_into_data(tensor: FloatTensor<Self>) -> Result<TensorData, ExecutionError> {
85        Ok(tensor
86            .into_data()
87            .await?
88            // Since underlying backends can have different data types, we convert to the current elem
89            .convert::<FloatElem<Self>>())
90    }
91
92    fn float_device(tensor: &FloatTensor<Self>) -> Device<Self> {
93        tensor.client.device()
94    }
95
96    fn float_to_device(tensor: FloatTensor<Self>, device: &Device<Self>) -> FloatTensor<Self> {
97        if &tensor.client.device() == device {
98            return tensor;
99        }
100        R::change_client_backend(tensor, device)
101    }
102
103    fn float_into_int(tensor: FloatTensor<Self>, out_dtype: IntDType) -> IntTensor<Self> {
104        let client = tensor.client.clone();
105        let desc = CastOpIr::create(tensor.into_ir(), out_dtype.into(), || {
106            client.create_empty_handle()
107        });
108
109        client
110            .register(OperationIr::Float(
111                desc.input.dtype,
112                FloatOperationIr::IntoInt(desc),
113            ))
114            .output()
115    }
116
117    fn float_empty(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
118        let client = get_client::<R>(device);
119        let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
120
121        client
122            .register(OperationIr::BaseFloat(BaseOperationIr::Empty(desc)))
123            .output()
124    }
125
126    fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
127        let client = lhs.client.clone();
128        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
129            client.create_empty_handle()
130        });
131
132        client
133            .register(OperationIr::NumericFloat(
134                desc.out.dtype,
135                NumericOperationIr::Add(desc),
136            ))
137            .output()
138    }
139
140    fn float_add_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
141        let client = lhs.client.clone();
142        let rhs = rhs.into();
143        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
144
145        client
146            .register(OperationIr::NumericFloat(
147                desc.out.dtype,
148                NumericOperationIr::AddScalar(desc),
149            ))
150            .output()
151    }
152
153    fn float_clamp(tensor: FloatTensor<Self>, min: Scalar, max: Scalar) -> FloatTensor<Self> {
154        let client = tensor.client.clone();
155        let min = min.into();
156        let max = max.into();
157        let desc = ClampOpIr::create(tensor.into_ir(), min, max, || client.create_empty_handle());
158
159        client
160            .register(OperationIr::NumericFloat(
161                desc.out.dtype,
162                NumericOperationIr::Clamp(desc),
163            ))
164            .output()
165    }
166
167    fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
168        let client = lhs.client.clone();
169        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
170            client.create_empty_handle()
171        });
172
173        client
174            .register(OperationIr::NumericFloat(
175                desc.out.dtype,
176                NumericOperationIr::Sub(desc),
177            ))
178            .output()
179    }
180
181    fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
182        let client = lhs.client.clone();
183        let rhs = rhs.into();
184        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
185
186        client
187            .register(OperationIr::NumericFloat(
188                desc.out.dtype,
189                NumericOperationIr::SubScalar(desc),
190            ))
191            .output()
192    }
193
194    fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
195        let client = lhs.client.clone();
196        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
197            client.create_empty_handle()
198        });
199
200        client
201            .register(OperationIr::NumericFloat(
202                desc.out.dtype,
203                NumericOperationIr::Mul(desc),
204            ))
205            .output()
206    }
207
208    fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
209        let client = lhs.client.clone();
210        let rhs = rhs.into();
211        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
212
213        client
214            .register(OperationIr::NumericFloat(
215                desc.out.dtype,
216                NumericOperationIr::MulScalar(desc),
217            ))
218            .output()
219    }
220
221    fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
222        let client = lhs.client.clone();
223        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
224            client.create_empty_handle()
225        });
226
227        client
228            .register(OperationIr::NumericFloat(
229                desc.out.dtype,
230                NumericOperationIr::Div(desc),
231            ))
232            .output()
233    }
234
235    fn float_div_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
236        let client = lhs.client.clone();
237        let rhs = rhs.into();
238        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
239
240        client
241            .register(OperationIr::NumericFloat(
242                desc.out.dtype,
243                NumericOperationIr::DivScalar(desc),
244            ))
245            .output()
246    }
247
248    fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
249        let client = lhs.client.clone();
250        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
251            client.create_empty_handle()
252        });
253
254        client
255            .register(OperationIr::NumericFloat(
256                desc.out.dtype,
257                NumericOperationIr::Rem(desc),
258            ))
259            .output()
260    }
261
262    fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
263        let client = lhs.client.clone();
264        let rhs = rhs.into();
265        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
266
267        client
268            .register(OperationIr::NumericFloat(
269                desc.out.dtype,
270                NumericOperationIr::RemScalar(desc),
271            ))
272            .output()
273    }
274
275    fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
276        let client = lhs.client.clone();
277        let desc = MatmulOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
278            client.create_empty_handle()
279        });
280
281        client
282            .register(OperationIr::Float(
283                desc.out.dtype,
284                FloatOperationIr::Matmul(desc),
285            ))
286            .output()
287    }
288
289    fn float_cross(
290        lhs: FloatTensor<Self>,
291        rhs: FloatTensor<Self>,
292        dim: usize,
293    ) -> FloatTensor<Self> {
294        let client = lhs.client.clone();
295        let desc = CrossOpIr::create(lhs.into_ir(), rhs.into_ir(), dim, || {
296            client.create_empty_handle()
297        });
298
299        client
300            .register(OperationIr::Float(
301                desc.out.dtype,
302                FloatOperationIr::Cross(desc),
303            ))
304            .output()
305    }
306
307    fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
308        let client = tensor.client.clone();
309        let desc = SwapDimsOpIr::create(tensor.into_ir(), dim1, dim2, || {
310            client.create_empty_handle()
311        });
312
313        client
314            .register(OperationIr::BaseFloat(BaseOperationIr::SwapDims(desc)))
315            .output()
316    }
317
318    fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
319        let client = tensor.client.clone();
320        let desc = ShapeOpIr::reshape(tensor.into_ir(), shape, || client.create_empty_handle());
321
322        client
323            .register(OperationIr::BaseFloat(BaseOperationIr::Reshape(desc)))
324            .output()
325    }
326
327    fn float_gather(
328        dim: usize,
329        tensor: FloatTensor<Self>,
330        indices: IntTensor<Self>,
331    ) -> FloatTensor<Self> {
332        let client = tensor.client.clone();
333        let desc = GatherOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
334            client.create_empty_handle()
335        });
336
337        client
338            .register(OperationIr::BaseFloat(BaseOperationIr::Gather(desc)))
339            .output()
340    }
341
342    fn float_scatter_add(
343        dim: usize,
344        tensor: FloatTensor<Self>,
345        indices: IntTensor<Self>,
346        value: FloatTensor<Self>,
347    ) -> FloatTensor<Self> {
348        let client = tensor.client.clone();
349        let desc = ScatterOpIr::create(
350            tensor.into_ir(),
351            dim,
352            indices.into_ir(),
353            value.into_ir(),
354            IndexingUpdateOp::Add,
355            || client.create_empty_handle(),
356        );
357
358        client
359            .register(OperationIr::BaseFloat(BaseOperationIr::Scatter(desc)))
360            .output()
361    }
362
363    fn float_scatter_nd(
364        data: FloatTensor<Self>,
365        indices: IntTensor<Self>,
366        values: FloatTensor<Self>,
367        reduction: IndexingUpdateOp,
368    ) -> FloatTensor<Self> {
369        let client = data.client.clone();
370        let desc = ScatterNdOpIr::create(
371            data.into_ir(),
372            indices.into_ir(),
373            values.into_ir(),
374            reduction,
375            || client.create_empty_handle(),
376        );
377
378        client
379            .register(OperationIr::BaseFloat(BaseOperationIr::ScatterNd(desc)))
380            .output()
381    }
382
383    fn float_gather_nd(data: FloatTensor<Self>, indices: IntTensor<Self>) -> FloatTensor<Self> {
384        let client = data.client.clone();
385        let desc = GatherNdOpIr::create(data.into_ir(), indices.into_ir(), || {
386            client.create_empty_handle()
387        });
388
389        client
390            .register(OperationIr::BaseFloat(BaseOperationIr::GatherNd(desc)))
391            .output()
392    }
393
394    fn float_select(
395        tensor: FloatTensor<Self>,
396        dim: usize,
397        indices: IntTensor<Self>,
398    ) -> FloatTensor<Self> {
399        let client = tensor.client.clone();
400        let desc = SelectOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
401            client.create_empty_handle()
402        });
403
404        client
405            .register(OperationIr::BaseFloat(BaseOperationIr::Select(desc)))
406            .output()
407    }
408
409    fn float_select_add(
410        tensor: FloatTensor<Self>,
411        dim: usize,
412        indices: IntTensor<Self>,
413        value: FloatTensor<Self>,
414    ) -> FloatTensor<Self> {
415        let client = tensor.client.clone();
416        let desc = SelectAssignOpIr::create(
417            tensor.into_ir(),
418            dim,
419            indices.into_ir(),
420            value.into_ir(),
421            IndexingUpdateOp::Add,
422            || client.create_empty_handle(),
423        );
424
425        client
426            .register(OperationIr::BaseFloat(BaseOperationIr::SelectAssign(desc)))
427            .output()
428    }
429
430    fn float_slice(tensor: FloatTensor<Self>, slices: &[Slice]) -> FloatTensor<Self> {
431        let client = tensor.client.clone();
432        let desc = SliceOpIr::create(tensor.into_ir(), slices.into(), || {
433            client.create_empty_handle()
434        });
435
436        client
437            .register(OperationIr::BaseFloat(BaseOperationIr::Slice(desc)))
438            .output()
439    }
440
441    fn float_slice_assign(
442        tensor: FloatTensor<Self>,
443        slices: &[burn_backend::Slice],
444        value: FloatTensor<Self>,
445    ) -> FloatTensor<Self> {
446        let client = tensor.client.clone();
447        let desc =
448            SliceAssignOpIr::create(tensor.into_ir(), slices.into(), value.into_ir(), || {
449                client.create_empty_handle()
450            });
451
452        client
453            .register(OperationIr::BaseFloat(BaseOperationIr::SliceAssign(desc)))
454            .output()
455    }
456
457    fn float_mask_where(
458        tensor: FloatTensor<Self>,
459        mask: BoolTensor<Self>,
460        value: FloatTensor<Self>,
461    ) -> FloatTensor<Self> {
462        let client = tensor.client.clone();
463        let desc = MaskWhereOpIr::create(tensor.into_ir(), mask.into_ir(), value.into_ir(), || {
464            client.create_empty_handle()
465        });
466
467        client
468            .register(OperationIr::BaseFloat(BaseOperationIr::MaskWhere(desc)))
469            .output()
470    }
471
472    fn float_mask_fill(
473        tensor: FloatTensor<Self>,
474        mask: BoolTensor<Self>,
475        value: Scalar,
476    ) -> FloatTensor<Self> {
477        let client = tensor.client.clone();
478        let value = value.into();
479        let desc = MaskFillOpIr::create(tensor.into_ir(), mask.into_ir(), value, || {
480            client.create_empty_handle()
481        });
482
483        client
484            .register(OperationIr::BaseFloat(BaseOperationIr::MaskFill(desc)))
485            .output()
486    }
487
488    fn float_equal(
489        lhs: FloatTensor<Self>,
490        rhs: FloatTensor<Self>,
491        out_dtype: BoolDType,
492    ) -> BoolTensor<Self> {
493        let client = lhs.client.clone();
494        let desc =
495            BinaryOpIr::create_comparison(lhs.into_ir(), rhs.into_ir(), out_dtype.into(), || {
496                client.create_empty_handle()
497            });
498
499        client
500            .register(OperationIr::BaseFloat(BaseOperationIr::Equal(desc)))
501            .output()
502    }
503
504    fn float_equal_elem(
505        lhs: FloatTensor<Self>,
506        rhs: Scalar,
507        out_dtype: BoolDType,
508    ) -> BoolTensor<Self> {
509        let client = lhs.client.clone();
510        let rhs = rhs.into();
511        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, out_dtype.into(), || {
512            client.create_empty_handle()
513        });
514
515        client
516            .register(OperationIr::BaseFloat(BaseOperationIr::EqualElem(desc)))
517            .output()
518    }
519
520    fn float_greater(
521        lhs: FloatTensor<Self>,
522        rhs: FloatTensor<Self>,
523        out_dtype: BoolDType,
524    ) -> BoolTensor<Self> {
525        let client = lhs.client.clone();
526        let desc =
527            BinaryOpIr::create_comparison(lhs.into_ir(), rhs.into_ir(), out_dtype.into(), || {
528                client.create_empty_handle()
529            });
530
531        client
532            .register(OperationIr::NumericFloat(
533                desc.lhs.dtype,
534                NumericOperationIr::Greater(desc),
535            ))
536            .output()
537    }
538
539    fn float_greater_elem(
540        lhs: FloatTensor<Self>,
541        rhs: Scalar,
542        out_dtype: BoolDType,
543    ) -> BoolTensor<Self> {
544        let client = lhs.client.clone();
545        let rhs = rhs.into();
546        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, out_dtype.into(), || {
547            client.create_empty_handle()
548        });
549
550        client
551            .register(OperationIr::NumericFloat(
552                desc.lhs.dtype,
553                NumericOperationIr::GreaterElem(desc),
554            ))
555            .output()
556    }
557
558    fn float_greater_equal(
559        lhs: FloatTensor<Self>,
560        rhs: FloatTensor<Self>,
561        out_dtype: BoolDType,
562    ) -> BoolTensor<Self> {
563        let client = lhs.client.clone();
564        let desc =
565            BinaryOpIr::create_comparison(lhs.into_ir(), rhs.into_ir(), out_dtype.into(), || {
566                client.create_empty_handle()
567            });
568
569        client
570            .register(OperationIr::NumericFloat(
571                desc.lhs.dtype,
572                NumericOperationIr::GreaterEqual(desc),
573            ))
574            .output()
575    }
576
577    fn float_greater_equal_elem(
578        lhs: FloatTensor<Self>,
579        rhs: Scalar,
580        out_dtype: BoolDType,
581    ) -> BoolTensor<Self> {
582        let client = lhs.client.clone();
583        let rhs = rhs.into();
584        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, out_dtype.into(), || {
585            client.create_empty_handle()
586        });
587
588        client
589            .register(OperationIr::NumericFloat(
590                desc.lhs.dtype,
591                NumericOperationIr::GreaterEqualElem(desc),
592            ))
593            .output()
594    }
595
596    fn float_lower(
597        lhs: FloatTensor<Self>,
598        rhs: FloatTensor<Self>,
599        out_dtype: BoolDType,
600    ) -> BoolTensor<Self> {
601        let client = lhs.client.clone();
602        let desc =
603            BinaryOpIr::create_comparison(lhs.into_ir(), rhs.into_ir(), out_dtype.into(), || {
604                client.create_empty_handle()
605            });
606
607        client
608            .register(OperationIr::NumericFloat(
609                desc.lhs.dtype,
610                NumericOperationIr::Lower(desc),
611            ))
612            .output()
613    }
614
615    fn float_lower_elem(
616        lhs: FloatTensor<Self>,
617        rhs: Scalar,
618        out_dtype: BoolDType,
619    ) -> BoolTensor<Self> {
620        let client = lhs.client.clone();
621        let rhs = rhs.into();
622        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, out_dtype.into(), || {
623            client.create_empty_handle()
624        });
625
626        client
627            .register(OperationIr::NumericFloat(
628                desc.lhs.dtype,
629                NumericOperationIr::LowerElem(desc),
630            ))
631            .output()
632    }
633
634    fn float_lower_equal(
635        lhs: FloatTensor<Self>,
636        rhs: FloatTensor<Self>,
637        out_dtype: BoolDType,
638    ) -> BoolTensor<Self> {
639        let client = lhs.client.clone();
640        let desc =
641            BinaryOpIr::create_comparison(lhs.into_ir(), rhs.into_ir(), out_dtype.into(), || {
642                client.create_empty_handle()
643            });
644
645        client
646            .register(OperationIr::NumericFloat(
647                desc.lhs.dtype,
648                NumericOperationIr::LowerEqual(desc),
649            ))
650            .output()
651    }
652
653    fn float_lower_equal_elem(
654        lhs: FloatTensor<Self>,
655        rhs: Scalar,
656        out_dtype: BoolDType,
657    ) -> BoolTensor<Self> {
658        let client = lhs.client.clone();
659        let rhs = rhs.into();
660        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, out_dtype.into(), || {
661            client.create_empty_handle()
662        });
663
664        client
665            .register(OperationIr::NumericFloat(
666                desc.lhs.dtype,
667                NumericOperationIr::LowerEqualElem(desc),
668            ))
669            .output()
670    }
671
672    fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
673        let client = tensor.client.clone();
674        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
675
676        client
677            .register(OperationIr::NumericFloat(
678                desc.out.dtype,
679                NumericOperationIr::Sum(desc),
680            ))
681            .output()
682    }
683
684    fn float_sum_dim(tensor: FloatTensor<Self>, axis: usize) -> FloatTensor<Self> {
685        let client = tensor.client.clone();
686        let desc =
687            ReduceDimOpIr::create(tensor.into_ir(), axis, 1, || client.create_empty_handle());
688
689        client
690            .register(OperationIr::NumericFloat(
691                desc.out.dtype,
692                NumericOperationIr::SumDim(desc),
693            ))
694            .output()
695    }
696
697    fn float_prod(tensor: IntTensor<Self>) -> IntTensor<Self> {
698        let client = tensor.client.clone();
699        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
700
701        client
702            .register(OperationIr::NumericFloat(
703                desc.out.dtype,
704                NumericOperationIr::Prod(desc),
705            ))
706            .output()
707    }
708
709    fn float_prod_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
710        let client = tensor.client.clone();
711        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
712
713        client
714            .register(OperationIr::NumericFloat(
715                desc.out.dtype,
716                NumericOperationIr::ProdDim(desc),
717            ))
718            .output()
719    }
720
721    fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
722        let client = tensor.client.clone();
723        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
724
725        client
726            .register(OperationIr::NumericFloat(
727                desc.out.dtype,
728                NumericOperationIr::Mean(desc),
729            ))
730            .output()
731    }
732
733    fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
734        let client = tensor.client.clone();
735        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
736
737        client
738            .register(OperationIr::NumericFloat(
739                desc.out.dtype,
740                NumericOperationIr::MeanDim(desc),
741            ))
742            .output()
743    }
744
745    fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
746        let client = tensor.client.clone();
747        let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
748
749        client
750            .register(OperationIr::NumericFloat(
751                desc.out.dtype,
752                NumericOperationIr::CumSum(desc),
753            ))
754            .output()
755    }
756
757    fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
758        let client = tensor.client.clone();
759        let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
760
761        client
762            .register(OperationIr::NumericFloat(
763                desc.out.dtype,
764                NumericOperationIr::CumProd(desc),
765            ))
766            .output()
767    }
768
769    fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
770        let client = tensor.client.clone();
771        let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
772
773        client
774            .register(OperationIr::NumericFloat(
775                desc.out.dtype,
776                NumericOperationIr::CumMin(desc),
777            ))
778            .output()
779    }
780
781    fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
782        let client = tensor.client.clone();
783        let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
784
785        client
786            .register(OperationIr::NumericFloat(
787                desc.out.dtype,
788                NumericOperationIr::CumMax(desc),
789            ))
790            .output()
791    }
792
793    fn float_exp(lhs: FloatTensor<Self>) -> FloatTensor<Self> {
794        let client = lhs.client.clone();
795        let desc = UnaryOpIr::create(lhs.into_ir(), || client.create_empty_handle());
796
797        client
798            .register(OperationIr::Float(
799                desc.out.dtype,
800                FloatOperationIr::Exp(desc),
801            ))
802            .output()
803    }
804
805    fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
806        let client = tensor.client.clone();
807        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
808
809        client
810            .register(OperationIr::Float(
811                desc.out.dtype,
812                FloatOperationIr::Log(desc),
813            ))
814            .output()
815    }
816
817    fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
818        let client = tensor.client.clone();
819        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
820
821        client
822            .register(OperationIr::Float(
823                desc.out.dtype,
824                FloatOperationIr::Log1p(desc),
825            ))
826            .output()
827    }
828
829    fn float_powf_scalar_impl(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
830        let client = lhs.client.clone();
831        let rhs = rhs.into();
832        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
833
834        client
835            .register(OperationIr::Float(
836                desc.out.dtype,
837                FloatOperationIr::PowfScalar(desc),
838            ))
839            .output()
840    }
841
842    fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
843        let client = tensor.client.clone();
844        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
845
846        client
847            .register(OperationIr::Float(
848                desc.out.dtype,
849                FloatOperationIr::Sqrt(desc),
850            ))
851            .output()
852    }
853
854    fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
855        let client = tensor.client.clone();
856        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
857
858        client
859            .register(OperationIr::NumericFloat(
860                desc.out.dtype,
861                NumericOperationIr::Abs(desc),
862            ))
863            .output()
864    }
865
866    fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
867        let client = tensor.client.clone();
868        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
869
870        client
871            .register(OperationIr::Float(
872                desc.out.dtype,
873                FloatOperationIr::Cos(desc),
874            ))
875            .output()
876    }
877
878    fn float_cosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
879        let client = tensor.client.clone();
880        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
881
882        client
883            .register(OperationIr::Float(
884                desc.out.dtype,
885                FloatOperationIr::Cosh(desc),
886            ))
887            .output()
888    }
889
890    fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
891        let client = tensor.client.clone();
892        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
893
894        client
895            .register(OperationIr::Float(
896                desc.out.dtype,
897                FloatOperationIr::Sin(desc),
898            ))
899            .output()
900    }
901
902    fn float_sinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
903        let client = tensor.client.clone();
904        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
905
906        client
907            .register(OperationIr::Float(
908                desc.out.dtype,
909                FloatOperationIr::Sinh(desc),
910            ))
911            .output()
912    }
913
914    fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
915        let client = tensor.client.clone();
916        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
917
918        client
919            .register(OperationIr::Float(
920                desc.out.dtype,
921                FloatOperationIr::Tan(desc),
922            ))
923            .output()
924    }
925
926    fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
927        let client = tensor.client.clone();
928        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
929
930        client
931            .register(OperationIr::Float(
932                desc.out.dtype,
933                FloatOperationIr::Tanh(desc),
934            ))
935            .output()
936    }
937
938    fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
939        let client = tensor.client.clone();
940        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
941
942        client
943            .register(OperationIr::Float(
944                desc.out.dtype,
945                FloatOperationIr::ArcCos(desc),
946            ))
947            .output()
948    }
949
950    fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
951        let client = tensor.client.clone();
952        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
953
954        client
955            .register(OperationIr::Float(
956                desc.out.dtype,
957                FloatOperationIr::ArcCosh(desc),
958            ))
959            .output()
960    }
961
962    fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
963        let client = tensor.client.clone();
964        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
965
966        client
967            .register(OperationIr::Float(
968                desc.out.dtype,
969                FloatOperationIr::ArcSin(desc),
970            ))
971            .output()
972    }
973
974    fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
975        let client = tensor.client.clone();
976        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
977
978        client
979            .register(OperationIr::Float(
980                desc.out.dtype,
981                FloatOperationIr::ArcSinh(desc),
982            ))
983            .output()
984    }
985
986    fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
987        let client = tensor.client.clone();
988        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
989
990        client
991            .register(OperationIr::Float(
992                desc.out.dtype,
993                FloatOperationIr::ArcTan(desc),
994            ))
995            .output()
996    }
997
998    fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
999        let client = tensor.client.clone();
1000        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1001
1002        client
1003            .register(OperationIr::Float(
1004                desc.out.dtype,
1005                FloatOperationIr::ArcTanh(desc),
1006            ))
1007            .output()
1008    }
1009
1010    fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
1011        let client = lhs.client.clone();
1012        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
1013            client.create_empty_handle()
1014        });
1015
1016        client
1017            .register(OperationIr::Float(
1018                desc.out.dtype,
1019                FloatOperationIr::ArcTan2(desc),
1020            ))
1021            .output()
1022    }
1023
1024    fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1025        let client = tensor.client.clone();
1026        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1027
1028        client
1029            .register(OperationIr::Float(
1030                desc.out.dtype,
1031                FloatOperationIr::Round(desc),
1032            ))
1033            .output()
1034    }
1035
1036    fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1037        let client = tensor.client.clone();
1038        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1039
1040        client
1041            .register(OperationIr::Float(
1042                desc.out.dtype,
1043                FloatOperationIr::Floor(desc),
1044            ))
1045            .output()
1046    }
1047
1048    fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1049        let client = tensor.client.clone();
1050        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1051
1052        client
1053            .register(OperationIr::Float(
1054                desc.out.dtype,
1055                FloatOperationIr::Ceil(desc),
1056            ))
1057            .output()
1058    }
1059
1060    fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1061        let client = tensor.client.clone();
1062        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1063
1064        client
1065            .register(OperationIr::Float(
1066                desc.out.dtype,
1067                FloatOperationIr::Trunc(desc),
1068            ))
1069            .output()
1070    }
1071
1072    fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1073        let client = tensor.client.clone();
1074        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1075
1076        client
1077            .register(OperationIr::Float(
1078                desc.out.dtype,
1079                FloatOperationIr::Recip(desc),
1080            ))
1081            .output()
1082    }
1083
1084    fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1085        let client = tensor.client.clone();
1086        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1087
1088        client
1089            .register(OperationIr::Float(
1090                desc.out.dtype,
1091                FloatOperationIr::Erf(desc),
1092            ))
1093            .output()
1094    }
1095
1096    fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {
1097        let client = tensors.first().unwrap().client.clone();
1098        let tensors = tensors.into_iter().map(|t| t.into_ir()).collect();
1099        let desc = CatOpIr::create(tensors, dim, || client.create_empty_handle());
1100
1101        client
1102            .register(OperationIr::BaseFloat(BaseOperationIr::Cat(desc)))
1103            .output()
1104    }
1105
1106    fn float_argmax(tensor: FloatTensor<Self>, dim: usize, out_dtype: IntDType) -> IntTensor<Self> {
1107        let client = tensor.client.clone();
1108        let desc = ReduceDimOpIr::create_arg(tensor.into_ir(), dim, 1, out_dtype.into(), || {
1109            client.create_empty_handle()
1110        });
1111
1112        client
1113            .register(OperationIr::NumericFloat(
1114                desc.input.dtype,
1115                NumericOperationIr::ArgMax(desc),
1116            ))
1117            .output()
1118    }
1119
1120    fn float_argtopk(
1121        tensor: FloatTensor<Self>,
1122        dim: usize,
1123        k: usize,
1124        out_dtype: IntDType,
1125    ) -> IntTensor<Self> {
1126        let client = tensor.client.clone();
1127        let desc = ReduceDimOpIr::create_arg(tensor.into_ir(), dim, k, out_dtype.into(), || {
1128            client.create_empty_handle()
1129        });
1130
1131        client
1132            .register(OperationIr::NumericFloat(
1133                desc.input.dtype,
1134                NumericOperationIr::ArgTopK(desc),
1135            ))
1136            .output()
1137    }
1138
1139    fn float_topk(tensor: FloatTensor<Self>, dim: usize, k: usize) -> FloatTensor<Self> {
1140        let client = tensor.client.clone();
1141        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, k, || client.create_empty_handle());
1142
1143        client
1144            .register(OperationIr::NumericFloat(
1145                desc.input.dtype,
1146                NumericOperationIr::TopK(desc),
1147            ))
1148            .output()
1149    }
1150
1151    fn float_repeat_dim(tensor: FloatTensor<Self>, dim: usize, times: usize) -> FloatTensor<Self> {
1152        let client = tensor.client.clone();
1153        let desc = RepeatDimOpIr::create(tensor.into_ir(), dim, times, || {
1154            client.create_empty_handle()
1155        });
1156
1157        client
1158            .register(OperationIr::BaseFloat(BaseOperationIr::RepeatDim(desc)))
1159            .output()
1160    }
1161
1162    fn float_argmin(tensor: FloatTensor<Self>, dim: usize, out_dtype: IntDType) -> IntTensor<Self> {
1163        let client = tensor.client.clone();
1164        let desc = ReduceDimOpIr::create_arg(tensor.into_ir(), dim, 1, out_dtype.into(), || {
1165            client.create_empty_handle()
1166        });
1167
1168        client
1169            .register(OperationIr::NumericFloat(
1170                desc.input.dtype,
1171                NumericOperationIr::ArgMin(desc),
1172            ))
1173            .output()
1174    }
1175
1176    fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1177        let client = tensor.client.clone();
1178        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1179
1180        client
1181            .register(OperationIr::NumericFloat(
1182                desc.out.dtype,
1183                NumericOperationIr::Max(desc),
1184            ))
1185            .output()
1186    }
1187
1188    fn float_max_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
1189        let client = tensor.client.clone();
1190        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
1191
1192        client
1193            .register(OperationIr::NumericFloat(
1194                desc.out.dtype,
1195                NumericOperationIr::MaxDim(desc),
1196            ))
1197            .output()
1198    }
1199
1200    fn float_max_dim_with_indices(
1201        tensor: FloatTensor<Self>,
1202        dim: usize,
1203        indices_dtype: IntDType,
1204    ) -> (FloatTensor<Self>, IntTensor<Self>) {
1205        let client = tensor.client.clone();
1206        let desc =
1207            ReduceDimWithIndicesOpIr::create(tensor.into_ir(), dim, indices_dtype.into(), || {
1208                client.create_empty_handle()
1209            });
1210
1211        client
1212            .register(OperationIr::NumericFloat(
1213                desc.tensor.dtype,
1214                NumericOperationIr::MaxDimWithIndices(desc),
1215            ))
1216            .outputs()
1217            .into()
1218    }
1219
1220    fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1221        let client = tensor.client.clone();
1222        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1223
1224        client
1225            .register(OperationIr::NumericFloat(
1226                desc.out.dtype,
1227                NumericOperationIr::Min(desc),
1228            ))
1229            .output()
1230    }
1231
1232    fn float_min_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
1233        let client = tensor.client.clone();
1234        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
1235
1236        client
1237            .register(OperationIr::NumericFloat(
1238                desc.out.dtype,
1239                NumericOperationIr::MinDim(desc),
1240            ))
1241            .output()
1242    }
1243
1244    fn float_min_dim_with_indices(
1245        tensor: FloatTensor<Self>,
1246        dim: usize,
1247        indices_dtype: IntDType,
1248    ) -> (FloatTensor<Self>, IntTensor<Self>) {
1249        let client = tensor.client.clone();
1250        let desc =
1251            ReduceDimWithIndicesOpIr::create(tensor.into_ir(), dim, indices_dtype.into(), || {
1252                client.create_empty_handle()
1253            });
1254
1255        client
1256            .register(OperationIr::NumericFloat(
1257                desc.tensor.dtype,
1258                NumericOperationIr::MinDimWithIndices(desc),
1259            ))
1260            .outputs()
1261            .into()
1262    }
1263
1264    fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
1265        let client = lhs.client.clone();
1266        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
1267            client.create_empty_handle()
1268        });
1269
1270        client
1271            .register(OperationIr::Float(
1272                desc.out.dtype,
1273                FloatOperationIr::Powf(desc),
1274            ))
1275            .output()
1276    }
1277
1278    fn float_powi(lhs: FloatTensor<Self>, rhs: IntTensor<Self>) -> FloatTensor<Self> {
1279        let client = lhs.client.clone();
1280        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
1281            client.create_empty_handle()
1282        });
1283
1284        client
1285            .register(OperationIr::NumericFloat(
1286                desc.out.dtype,
1287                NumericOperationIr::Powi(desc),
1288            ))
1289            .output()
1290    }
1291
1292    fn float_powi_scalar_impl(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
1293        let client = lhs.client.clone();
1294        let rhs = rhs.into();
1295        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
1296
1297        client
1298            .register(OperationIr::NumericFloat(
1299                desc.out.dtype,
1300                NumericOperationIr::PowiScalar(desc),
1301            ))
1302            .output()
1303    }
1304
1305    fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
1306        let client = tensor.client.clone();
1307        let desc = PermuteOpIr::create(tensor.into_ir(), axes.into(), || {
1308            client.create_empty_handle()
1309        });
1310
1311        client
1312            .register(OperationIr::BaseFloat(BaseOperationIr::Permute(desc)))
1313            .output()
1314    }
1315
1316    fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
1317        let client = tensor.client.clone();
1318        let desc = ShapeOpIr::expand(tensor.into_ir(), shape, || client.create_empty_handle());
1319
1320        client
1321            .register(OperationIr::BaseFloat(BaseOperationIr::Expand(desc)))
1322            .output()
1323    }
1324
1325    fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
1326        let client = tensor.client.clone();
1327        let desc = FlipOpIr::create(tensor.into_ir(), axes.into(), || {
1328            client.create_empty_handle()
1329        });
1330
1331        client
1332            .register(OperationIr::BaseFloat(BaseOperationIr::Flip(desc)))
1333            .output()
1334    }
1335
1336    fn float_cast(tensor: FloatTensor<Self>, dtype: burn_backend::FloatDType) -> FloatTensor<Self> {
1337        let client = tensor.client.clone();
1338        let desc = CastOpIr::create(tensor.into_ir(), dtype.into(), || {
1339            client.create_empty_handle()
1340        });
1341
1342        client
1343            .register(OperationIr::BaseFloat(BaseOperationIr::Cast(desc)))
1344            .output()
1345    }
1346
1347    fn float_unfold(
1348        tensor: FloatTensor<Self>,
1349        dim: usize,
1350        size: usize,
1351        step: usize,
1352    ) -> FloatTensor<Self> {
1353        let client = tensor.client.clone();
1354        let desc = UnfoldOpIr::create(tensor.into_ir(), dim, size, step, || {
1355            client.create_empty_handle()
1356        });
1357
1358        client
1359            .register(OperationIr::BaseFloat(BaseOperationIr::Unfold(desc)))
1360            .output()
1361    }
1362}