Skip to main content

burn_router/ops/
tensor.rs

1use alloc::vec::Vec;
2use burn_backend::backend::{Backend, ExecutionError};
3
4use crate::{BackendRouter, RunnerChannel, RunnerClient, get_client};
5use burn_backend::tensor::{
6    BoolTensor, Device, FloatElem, FloatTensor, IndexingUpdateOp, IntElem, IntTensor,
7};
8use burn_backend::{
9    Distribution, Element, FloatDType, Shape, Slice, TensorData, ops::FloatTensorOps,
10};
11use burn_ir::{
12    BaseOperationIr, BinaryOpIr, CastOpIr, CatOpIr, ClampOpIr, CreationOpIr, CrossOpIr, DimOpIr,
13    FlipOpIr, FloatOperationIr, FullOpIr, GatherOpIr, InitOperationIr, MaskFillOpIr, MaskWhereOpIr,
14    MatmulOpIr, NumericOperationIr, OperationIr, OperationOutput, PermuteOpIr, RandomOpIr,
15    ReduceDimOpIr, ReduceDimWithIndicesOpIr, ReduceOpIr, RepeatDimOpIr, ScalarIr, ScalarOpIr,
16    ScatterOpIr, SelectAssignOpIr, SelectOpIr, ShapeOpIr, SliceAssignOpIr, SliceOpIr, SwapDimsOpIr,
17    UnaryOpIr, UnfoldOpIr,
18};
19
20impl<R: RunnerChannel> FloatTensorOps<Self> for BackendRouter<R> {
21    fn float_from_data(data: TensorData, device: &Device<Self>) -> FloatTensor<Self> {
22        let client = get_client::<R>(device);
23        let out = client.register_tensor_data(data);
24        let desc = InitOperationIr {
25            out: out.to_ir_out(),
26        };
27
28        // Call register op when output is already initialized
29        client.register_op(OperationIr::Init(desc));
30
31        out
32    }
33
34    fn float_random(
35        shape: Shape,
36        distribution: Distribution,
37        device: &Device<Self>,
38    ) -> FloatTensor<Self> {
39        let client = get_client::<R>(device);
40        let dtype = FloatElem::<Self>::dtype();
41        let desc = RandomOpIr::create(shape, dtype, distribution, || client.create_empty_handle());
42
43        client
44            .register(OperationIr::Float(dtype, FloatOperationIr::Random(desc)))
45            .output()
46    }
47
48    fn float_zeros(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
49        let client = get_client::<R>(device);
50        let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
51
52        client
53            .register(OperationIr::BaseFloat(BaseOperationIr::Zeros(desc)))
54            .output()
55    }
56
57    fn float_ones(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
58        let client = get_client::<R>(device);
59        let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
60
61        client
62            .register(OperationIr::BaseFloat(BaseOperationIr::Ones(desc)))
63            .output()
64    }
65
66    fn float_full(
67        shape: Shape,
68        fill_value: FloatElem<Self>,
69        device: &Device<Self>,
70        dtype: FloatDType,
71    ) -> FloatTensor<Self> {
72        let client = get_client::<R>(device);
73        let dtype = dtype.into();
74        let value = ScalarIr::with_dtype(fill_value, &dtype);
75        let desc = FullOpIr::create(shape, dtype, value, || client.create_empty_handle());
76
77        client
78            .register(OperationIr::NumericFloat(
79                desc.out.dtype,
80                NumericOperationIr::Full(desc),
81            ))
82            .output()
83    }
84
85    async fn float_into_data(tensor: FloatTensor<Self>) -> Result<TensorData, ExecutionError> {
86        Ok(tensor
87            .into_data()
88            .await?
89            // Since underlying backends can have different data types, we convert to the current elem
90            .convert::<<Self as Backend>::FloatElem>())
91    }
92
93    fn float_device(tensor: &FloatTensor<Self>) -> Device<Self> {
94        tensor.client.device()
95    }
96
97    fn float_to_device(tensor: FloatTensor<Self>, device: &Device<Self>) -> FloatTensor<Self> {
98        if &tensor.client.device() == device {
99            return tensor;
100        }
101        R::change_client_backend(tensor, device)
102    }
103
104    fn float_into_int(tensor: FloatTensor<Self>) -> IntTensor<Self> {
105        let client = tensor.client.clone();
106        let desc = CastOpIr::create(tensor.into_ir(), IntElem::<Self>::dtype(), || {
107            client.create_empty_handle()
108        });
109
110        client
111            .register(OperationIr::Float(
112                desc.input.dtype,
113                FloatOperationIr::IntoInt(desc),
114            ))
115            .output()
116    }
117
118    fn float_empty(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
119        let client = get_client::<R>(device);
120        let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
121
122        client
123            .register(OperationIr::BaseFloat(BaseOperationIr::Empty(desc)))
124            .output()
125    }
126
127    fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
128        let client = lhs.client.clone();
129        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
130            client.create_empty_handle()
131        });
132
133        client
134            .register(OperationIr::NumericFloat(
135                desc.out.dtype,
136                NumericOperationIr::Add(desc),
137            ))
138            .output()
139    }
140
141    fn float_add_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
142        let client = lhs.client.clone();
143        let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
144        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
145
146        client
147            .register(OperationIr::NumericFloat(
148                desc.out.dtype,
149                NumericOperationIr::AddScalar(desc),
150            ))
151            .output()
152    }
153
154    fn float_clamp(
155        tensor: FloatTensor<Self>,
156        min: FloatElem<Self>,
157        max: FloatElem<Self>,
158    ) -> FloatTensor<Self> {
159        let client = tensor.client.clone();
160        let min = ScalarIr::with_dtype(min, &tensor.dtype);
161        let max = ScalarIr::with_dtype(max, &tensor.dtype);
162        let desc = ClampOpIr::create(tensor.into_ir(), min, max, || client.create_empty_handle());
163
164        client
165            .register(OperationIr::NumericFloat(
166                desc.out.dtype,
167                NumericOperationIr::Clamp(desc),
168            ))
169            .output()
170    }
171
172    fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
173        let client = lhs.client.clone();
174        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
175            client.create_empty_handle()
176        });
177
178        client
179            .register(OperationIr::NumericFloat(
180                desc.out.dtype,
181                NumericOperationIr::Sub(desc),
182            ))
183            .output()
184    }
185
186    fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
187        let client = lhs.client.clone();
188        let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
189        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
190
191        client
192            .register(OperationIr::NumericFloat(
193                desc.out.dtype,
194                NumericOperationIr::SubScalar(desc),
195            ))
196            .output()
197    }
198
199    fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
200        let client = lhs.client.clone();
201        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
202            client.create_empty_handle()
203        });
204
205        client
206            .register(OperationIr::NumericFloat(
207                desc.out.dtype,
208                NumericOperationIr::Mul(desc),
209            ))
210            .output()
211    }
212
213    fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
214        let client = lhs.client.clone();
215        let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
216        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
217
218        client
219            .register(OperationIr::NumericFloat(
220                desc.out.dtype,
221                NumericOperationIr::MulScalar(desc),
222            ))
223            .output()
224    }
225
226    fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
227        let client = lhs.client.clone();
228        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
229            client.create_empty_handle()
230        });
231
232        client
233            .register(OperationIr::NumericFloat(
234                desc.out.dtype,
235                NumericOperationIr::Div(desc),
236            ))
237            .output()
238    }
239
240    fn float_div_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
241        let client = lhs.client.clone();
242        let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
243        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
244
245        client
246            .register(OperationIr::NumericFloat(
247                desc.out.dtype,
248                NumericOperationIr::DivScalar(desc),
249            ))
250            .output()
251    }
252
253    fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
254        let client = lhs.client.clone();
255        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
256            client.create_empty_handle()
257        });
258
259        client
260            .register(OperationIr::NumericFloat(
261                desc.out.dtype,
262                NumericOperationIr::Rem(desc),
263            ))
264            .output()
265    }
266
267    fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
268        let client = lhs.client.clone();
269        let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
270        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
271
272        client
273            .register(OperationIr::NumericFloat(
274                desc.out.dtype,
275                NumericOperationIr::RemScalar(desc),
276            ))
277            .output()
278    }
279
280    fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
281        let client = lhs.client.clone();
282        let desc = MatmulOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
283            client.create_empty_handle()
284        });
285
286        client
287            .register(OperationIr::Float(
288                desc.out.dtype,
289                FloatOperationIr::Matmul(desc),
290            ))
291            .output()
292    }
293
294    fn float_cross(
295        lhs: FloatTensor<Self>,
296        rhs: FloatTensor<Self>,
297        dim: usize,
298    ) -> FloatTensor<Self> {
299        let client = lhs.client.clone();
300        let desc = CrossOpIr::create(lhs.into_ir(), rhs.into_ir(), dim, || {
301            client.create_empty_handle()
302        });
303
304        client
305            .register(OperationIr::Float(
306                desc.out.dtype,
307                FloatOperationIr::Cross(desc),
308            ))
309            .output()
310    }
311
312    fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
313        let client = tensor.client.clone();
314        let desc = SwapDimsOpIr::create(tensor.into_ir(), dim1, dim2, || {
315            client.create_empty_handle()
316        });
317
318        client
319            .register(OperationIr::BaseFloat(BaseOperationIr::SwapDims(desc)))
320            .output()
321    }
322
323    fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
324        let client = tensor.client.clone();
325        let desc = ShapeOpIr::reshape(tensor.into_ir(), shape, || client.create_empty_handle());
326
327        client
328            .register(OperationIr::BaseFloat(BaseOperationIr::Reshape(desc)))
329            .output()
330    }
331
332    fn float_gather(
333        dim: usize,
334        tensor: FloatTensor<Self>,
335        indices: IntTensor<Self>,
336    ) -> FloatTensor<Self> {
337        let client = tensor.client.clone();
338        let desc = GatherOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
339            client.create_empty_handle()
340        });
341
342        client
343            .register(OperationIr::BaseFloat(BaseOperationIr::Gather(desc)))
344            .output()
345    }
346
347    fn float_scatter_add(
348        dim: usize,
349        tensor: FloatTensor<Self>,
350        indices: IntTensor<Self>,
351        value: FloatTensor<Self>,
352    ) -> FloatTensor<Self> {
353        let client = tensor.client.clone();
354        let desc = ScatterOpIr::create(
355            tensor.into_ir(),
356            dim,
357            indices.into_ir(),
358            value.into_ir(),
359            IndexingUpdateOp::Add,
360            || client.create_empty_handle(),
361        );
362
363        client
364            .register(OperationIr::BaseFloat(BaseOperationIr::Scatter(desc)))
365            .output()
366    }
367
368    fn float_select(
369        tensor: FloatTensor<Self>,
370        dim: usize,
371        indices: IntTensor<Self>,
372    ) -> FloatTensor<Self> {
373        let client = tensor.client.clone();
374        let desc = SelectOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
375            client.create_empty_handle()
376        });
377
378        client
379            .register(OperationIr::BaseFloat(BaseOperationIr::Select(desc)))
380            .output()
381    }
382
383    fn float_select_add(
384        tensor: FloatTensor<Self>,
385        dim: usize,
386        indices: IntTensor<Self>,
387        value: FloatTensor<Self>,
388    ) -> FloatTensor<Self> {
389        let client = tensor.client.clone();
390        let desc = SelectAssignOpIr::create(
391            tensor.into_ir(),
392            dim,
393            indices.into_ir(),
394            value.into_ir(),
395            IndexingUpdateOp::Add,
396            || client.create_empty_handle(),
397        );
398
399        client
400            .register(OperationIr::BaseFloat(BaseOperationIr::SelectAssign(desc)))
401            .output()
402    }
403
404    fn float_slice(tensor: FloatTensor<Self>, slices: &[Slice]) -> FloatTensor<Self> {
405        let client = tensor.client.clone();
406        let desc = SliceOpIr::create(tensor.into_ir(), slices.into(), || {
407            client.create_empty_handle()
408        });
409
410        client
411            .register(OperationIr::BaseFloat(BaseOperationIr::Slice(desc)))
412            .output()
413    }
414
415    fn float_slice_assign(
416        tensor: FloatTensor<Self>,
417        slices: &[burn_backend::Slice],
418        value: FloatTensor<Self>,
419    ) -> FloatTensor<Self> {
420        let client = tensor.client.clone();
421        let desc =
422            SliceAssignOpIr::create(tensor.into_ir(), slices.into(), value.into_ir(), || {
423                client.create_empty_handle()
424            });
425
426        client
427            .register(OperationIr::BaseFloat(BaseOperationIr::SliceAssign(desc)))
428            .output()
429    }
430
431    fn float_mask_where(
432        tensor: FloatTensor<Self>,
433        mask: BoolTensor<Self>,
434        value: FloatTensor<Self>,
435    ) -> FloatTensor<Self> {
436        let client = tensor.client.clone();
437        let desc = MaskWhereOpIr::create(tensor.into_ir(), mask.into_ir(), value.into_ir(), || {
438            client.create_empty_handle()
439        });
440
441        client
442            .register(OperationIr::BaseFloat(BaseOperationIr::MaskWhere(desc)))
443            .output()
444    }
445
446    fn float_mask_fill(
447        tensor: FloatTensor<Self>,
448        mask: BoolTensor<Self>,
449        value: FloatElem<Self>,
450    ) -> FloatTensor<Self> {
451        let client = tensor.client.clone();
452        let value = ScalarIr::with_dtype(value, &tensor.dtype);
453        let desc = MaskFillOpIr::create(tensor.into_ir(), mask.into_ir(), value, || {
454            client.create_empty_handle()
455        });
456
457        client
458            .register(OperationIr::BaseFloat(BaseOperationIr::MaskFill(desc)))
459            .output()
460    }
461
462    fn float_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
463        let client = lhs.client.clone();
464        let desc = BinaryOpIr::create_comparison(
465            lhs.into_ir(),
466            rhs.into_ir(),
467            R::BoolElem::dtype(),
468            || client.create_empty_handle(),
469        );
470
471        client
472            .register(OperationIr::BaseFloat(BaseOperationIr::Equal(desc)))
473            .output()
474    }
475
476    fn float_equal_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
477        let client = lhs.client.clone();
478        let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
479        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || {
480            client.create_empty_handle()
481        });
482
483        client
484            .register(OperationIr::BaseFloat(BaseOperationIr::EqualElem(desc)))
485            .output()
486    }
487
488    fn float_greater(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
489        let client = lhs.client.clone();
490        let desc = BinaryOpIr::create_comparison(
491            lhs.into_ir(),
492            rhs.into_ir(),
493            R::BoolElem::dtype(),
494            || client.create_empty_handle(),
495        );
496
497        client
498            .register(OperationIr::NumericFloat(
499                desc.lhs.dtype,
500                NumericOperationIr::Greater(desc),
501            ))
502            .output()
503    }
504
505    fn float_greater_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
506        let client = lhs.client.clone();
507        let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
508        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || {
509            client.create_empty_handle()
510        });
511
512        client
513            .register(OperationIr::NumericFloat(
514                desc.lhs.dtype,
515                NumericOperationIr::GreaterElem(desc),
516            ))
517            .output()
518    }
519
520    fn float_greater_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
521        let client = lhs.client.clone();
522        let desc = BinaryOpIr::create_comparison(
523            lhs.into_ir(),
524            rhs.into_ir(),
525            R::BoolElem::dtype(),
526            || client.create_empty_handle(),
527        );
528
529        client
530            .register(OperationIr::NumericFloat(
531                desc.lhs.dtype,
532                NumericOperationIr::GreaterEqual(desc),
533            ))
534            .output()
535    }
536
537    fn float_greater_equal_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
538        let client = lhs.client.clone();
539        let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
540        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || {
541            client.create_empty_handle()
542        });
543
544        client
545            .register(OperationIr::NumericFloat(
546                desc.lhs.dtype,
547                NumericOperationIr::GreaterEqualElem(desc),
548            ))
549            .output()
550    }
551
552    fn float_lower(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
553        let client = lhs.client.clone();
554        let desc = BinaryOpIr::create_comparison(
555            lhs.into_ir(),
556            rhs.into_ir(),
557            R::BoolElem::dtype(),
558            || client.create_empty_handle(),
559        );
560
561        client
562            .register(OperationIr::NumericFloat(
563                desc.lhs.dtype,
564                NumericOperationIr::Lower(desc),
565            ))
566            .output()
567    }
568
569    fn float_lower_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
570        let client = lhs.client.clone();
571        let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
572        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || {
573            client.create_empty_handle()
574        });
575
576        client
577            .register(OperationIr::NumericFloat(
578                desc.lhs.dtype,
579                NumericOperationIr::LowerElem(desc),
580            ))
581            .output()
582    }
583
584    fn float_lower_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
585        let client = lhs.client.clone();
586        let desc = BinaryOpIr::create_comparison(
587            lhs.into_ir(),
588            rhs.into_ir(),
589            R::BoolElem::dtype(),
590            || client.create_empty_handle(),
591        );
592
593        client
594            .register(OperationIr::NumericFloat(
595                desc.lhs.dtype,
596                NumericOperationIr::LowerEqual(desc),
597            ))
598            .output()
599    }
600
601    fn float_lower_equal_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
602        let client = lhs.client.clone();
603        let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
604        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || {
605            client.create_empty_handle()
606        });
607
608        client
609            .register(OperationIr::NumericFloat(
610                desc.lhs.dtype,
611                NumericOperationIr::LowerEqualElem(desc),
612            ))
613            .output()
614    }
615
616    fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
617        let client = tensor.client.clone();
618        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
619
620        client
621            .register(OperationIr::NumericFloat(
622                desc.out.dtype,
623                NumericOperationIr::Sum(desc),
624            ))
625            .output()
626    }
627
628    fn float_sum_dim(tensor: FloatTensor<Self>, axis: usize) -> FloatTensor<Self> {
629        let client = tensor.client.clone();
630        let desc = ReduceDimOpIr::create(tensor.into_ir(), axis, || client.create_empty_handle());
631
632        client
633            .register(OperationIr::NumericFloat(
634                desc.out.dtype,
635                NumericOperationIr::SumDim(desc),
636            ))
637            .output()
638    }
639
640    fn float_prod(tensor: IntTensor<Self>) -> IntTensor<Self> {
641        let client = tensor.client.clone();
642        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
643
644        client
645            .register(OperationIr::NumericFloat(
646                desc.out.dtype,
647                NumericOperationIr::Prod(desc),
648            ))
649            .output()
650    }
651
652    fn float_prod_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
653        let client = tensor.client.clone();
654        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
655
656        client
657            .register(OperationIr::NumericFloat(
658                desc.out.dtype,
659                NumericOperationIr::ProdDim(desc),
660            ))
661            .output()
662    }
663
664    fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
665        let client = tensor.client.clone();
666        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
667
668        client
669            .register(OperationIr::NumericFloat(
670                desc.out.dtype,
671                NumericOperationIr::Mean(desc),
672            ))
673            .output()
674    }
675
676    fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
677        let client = tensor.client.clone();
678        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
679
680        client
681            .register(OperationIr::NumericFloat(
682                desc.out.dtype,
683                NumericOperationIr::MeanDim(desc),
684            ))
685            .output()
686    }
687
688    fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
689        let client = tensor.client.clone();
690        let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
691
692        client
693            .register(OperationIr::NumericFloat(
694                desc.out.dtype,
695                NumericOperationIr::CumSum(desc),
696            ))
697            .output()
698    }
699
700    fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
701        let client = tensor.client.clone();
702        let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
703
704        client
705            .register(OperationIr::NumericFloat(
706                desc.out.dtype,
707                NumericOperationIr::CumProd(desc),
708            ))
709            .output()
710    }
711
712    fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
713        let client = tensor.client.clone();
714        let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
715
716        client
717            .register(OperationIr::NumericFloat(
718                desc.out.dtype,
719                NumericOperationIr::CumMin(desc),
720            ))
721            .output()
722    }
723
724    fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
725        let client = tensor.client.clone();
726        let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
727
728        client
729            .register(OperationIr::NumericFloat(
730                desc.out.dtype,
731                NumericOperationIr::CumMax(desc),
732            ))
733            .output()
734    }
735
736    fn float_exp(lhs: FloatTensor<Self>) -> FloatTensor<Self> {
737        let client = lhs.client.clone();
738        let desc = UnaryOpIr::create(lhs.into_ir(), || client.create_empty_handle());
739
740        client
741            .register(OperationIr::Float(
742                desc.out.dtype,
743                FloatOperationIr::Exp(desc),
744            ))
745            .output()
746    }
747
748    fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
749        let client = tensor.client.clone();
750        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
751
752        client
753            .register(OperationIr::Float(
754                desc.out.dtype,
755                FloatOperationIr::Log(desc),
756            ))
757            .output()
758    }
759
760    fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
761        let client = tensor.client.clone();
762        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
763
764        client
765            .register(OperationIr::Float(
766                desc.out.dtype,
767                FloatOperationIr::Log1p(desc),
768            ))
769            .output()
770    }
771
772    fn float_powf_scalar_impl(lhs: FloatTensor<Self>, rhs: f32) -> FloatTensor<Self> {
773        let client = lhs.client.clone();
774        let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
775        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
776
777        client
778            .register(OperationIr::Float(
779                desc.out.dtype,
780                FloatOperationIr::PowfScalar(desc),
781            ))
782            .output()
783    }
784
785    fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
786        let client = tensor.client.clone();
787        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
788
789        client
790            .register(OperationIr::Float(
791                desc.out.dtype,
792                FloatOperationIr::Sqrt(desc),
793            ))
794            .output()
795    }
796
797    fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
798        let client = tensor.client.clone();
799        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
800
801        client
802            .register(OperationIr::NumericFloat(
803                desc.out.dtype,
804                NumericOperationIr::Abs(desc),
805            ))
806            .output()
807    }
808
809    fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
810        let client = tensor.client.clone();
811        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
812
813        client
814            .register(OperationIr::Float(
815                desc.out.dtype,
816                FloatOperationIr::Cos(desc),
817            ))
818            .output()
819    }
820
821    fn float_cosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
822        let client = tensor.client.clone();
823        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
824
825        client
826            .register(OperationIr::Float(
827                desc.out.dtype,
828                FloatOperationIr::Cosh(desc),
829            ))
830            .output()
831    }
832
833    fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
834        let client = tensor.client.clone();
835        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
836
837        client
838            .register(OperationIr::Float(
839                desc.out.dtype,
840                FloatOperationIr::Sin(desc),
841            ))
842            .output()
843    }
844
845    fn float_sinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
846        let client = tensor.client.clone();
847        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
848
849        client
850            .register(OperationIr::Float(
851                desc.out.dtype,
852                FloatOperationIr::Sinh(desc),
853            ))
854            .output()
855    }
856
857    fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
858        let client = tensor.client.clone();
859        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
860
861        client
862            .register(OperationIr::Float(
863                desc.out.dtype,
864                FloatOperationIr::Tan(desc),
865            ))
866            .output()
867    }
868
869    fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
870        let client = tensor.client.clone();
871        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
872
873        client
874            .register(OperationIr::Float(
875                desc.out.dtype,
876                FloatOperationIr::Tanh(desc),
877            ))
878            .output()
879    }
880
881    fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
882        let client = tensor.client.clone();
883        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
884
885        client
886            .register(OperationIr::Float(
887                desc.out.dtype,
888                FloatOperationIr::ArcCos(desc),
889            ))
890            .output()
891    }
892
893    fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
894        let client = tensor.client.clone();
895        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
896
897        client
898            .register(OperationIr::Float(
899                desc.out.dtype,
900                FloatOperationIr::ArcCosh(desc),
901            ))
902            .output()
903    }
904
905    fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
906        let client = tensor.client.clone();
907        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
908
909        client
910            .register(OperationIr::Float(
911                desc.out.dtype,
912                FloatOperationIr::ArcSin(desc),
913            ))
914            .output()
915    }
916
917    fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
918        let client = tensor.client.clone();
919        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
920
921        client
922            .register(OperationIr::Float(
923                desc.out.dtype,
924                FloatOperationIr::ArcSinh(desc),
925            ))
926            .output()
927    }
928
929    fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
930        let client = tensor.client.clone();
931        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
932
933        client
934            .register(OperationIr::Float(
935                desc.out.dtype,
936                FloatOperationIr::ArcTan(desc),
937            ))
938            .output()
939    }
940
941    fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
942        let client = tensor.client.clone();
943        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
944
945        client
946            .register(OperationIr::Float(
947                desc.out.dtype,
948                FloatOperationIr::ArcTanh(desc),
949            ))
950            .output()
951    }
952
953    fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
954        let client = lhs.client.clone();
955        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
956            client.create_empty_handle()
957        });
958
959        client
960            .register(OperationIr::Float(
961                desc.out.dtype,
962                FloatOperationIr::ArcTan2(desc),
963            ))
964            .output()
965    }
966
967    fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
968        let client = tensor.client.clone();
969        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
970
971        client
972            .register(OperationIr::Float(
973                desc.out.dtype,
974                FloatOperationIr::Round(desc),
975            ))
976            .output()
977    }
978
979    fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
980        let client = tensor.client.clone();
981        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
982
983        client
984            .register(OperationIr::Float(
985                desc.out.dtype,
986                FloatOperationIr::Floor(desc),
987            ))
988            .output()
989    }
990
991    fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
992        let client = tensor.client.clone();
993        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
994
995        client
996            .register(OperationIr::Float(
997                desc.out.dtype,
998                FloatOperationIr::Ceil(desc),
999            ))
1000            .output()
1001    }
1002
1003    fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1004        let client = tensor.client.clone();
1005        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1006
1007        client
1008            .register(OperationIr::Float(
1009                desc.out.dtype,
1010                FloatOperationIr::Trunc(desc),
1011            ))
1012            .output()
1013    }
1014
1015    fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1016        let client = tensor.client.clone();
1017        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1018
1019        client
1020            .register(OperationIr::Float(
1021                desc.out.dtype,
1022                FloatOperationIr::Recip(desc),
1023            ))
1024            .output()
1025    }
1026
1027    fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1028        let client = tensor.client.clone();
1029        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1030
1031        client
1032            .register(OperationIr::Float(
1033                desc.out.dtype,
1034                FloatOperationIr::Erf(desc),
1035            ))
1036            .output()
1037    }
1038
1039    fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {
1040        let client = tensors.first().unwrap().client.clone();
1041        let tensors = tensors.into_iter().map(|t| t.into_ir()).collect();
1042        let desc = CatOpIr::create(tensors, dim, || client.create_empty_handle());
1043
1044        client
1045            .register(OperationIr::BaseFloat(BaseOperationIr::Cat(desc)))
1046            .output()
1047    }
1048
1049    fn float_argmax(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {
1050        let client = tensor.client.clone();
1051        let desc =
1052            ReduceDimOpIr::create_arg(tensor.into_ir(), dim, IntElem::<Self>::dtype(), || {
1053                client.create_empty_handle()
1054            });
1055
1056        client
1057            .register(OperationIr::NumericFloat(
1058                desc.input.dtype,
1059                NumericOperationIr::ArgMax(desc),
1060            ))
1061            .output()
1062    }
1063
1064    fn float_repeat_dim(tensor: FloatTensor<Self>, dim: usize, times: usize) -> FloatTensor<Self> {
1065        let client = tensor.client.clone();
1066        let desc = RepeatDimOpIr::create(tensor.into_ir(), dim, times, || {
1067            client.create_empty_handle()
1068        });
1069
1070        client
1071            .register(OperationIr::BaseFloat(BaseOperationIr::RepeatDim(desc)))
1072            .output()
1073    }
1074
1075    fn float_argmin(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {
1076        let client = tensor.client.clone();
1077        let desc =
1078            ReduceDimOpIr::create_arg(tensor.into_ir(), dim, IntElem::<Self>::dtype(), || {
1079                client.create_empty_handle()
1080            });
1081
1082        client
1083            .register(OperationIr::NumericFloat(
1084                desc.input.dtype,
1085                NumericOperationIr::ArgMin(desc),
1086            ))
1087            .output()
1088    }
1089
1090    fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1091        let client = tensor.client.clone();
1092        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1093
1094        client
1095            .register(OperationIr::NumericFloat(
1096                desc.out.dtype,
1097                NumericOperationIr::Max(desc),
1098            ))
1099            .output()
1100    }
1101
1102    fn float_max_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
1103        let client = tensor.client.clone();
1104        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
1105
1106        client
1107            .register(OperationIr::NumericFloat(
1108                desc.out.dtype,
1109                NumericOperationIr::MaxDim(desc),
1110            ))
1111            .output()
1112    }
1113
1114    fn float_max_dim_with_indices(
1115        tensor: FloatTensor<Self>,
1116        dim: usize,
1117    ) -> (FloatTensor<Self>, IntTensor<Self>) {
1118        let client = tensor.client.clone();
1119        let desc = ReduceDimWithIndicesOpIr::create(
1120            tensor.into_ir(),
1121            dim,
1122            IntElem::<Self>::dtype(),
1123            || client.create_empty_handle(),
1124        );
1125
1126        client
1127            .register(OperationIr::NumericFloat(
1128                desc.tensor.dtype,
1129                NumericOperationIr::MaxDimWithIndices(desc),
1130            ))
1131            .outputs()
1132            .into()
1133    }
1134
1135    fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1136        let client = tensor.client.clone();
1137        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1138
1139        client
1140            .register(OperationIr::NumericFloat(
1141                desc.out.dtype,
1142                NumericOperationIr::Min(desc),
1143            ))
1144            .output()
1145    }
1146
1147    fn float_min_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
1148        let client = tensor.client.clone();
1149        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
1150
1151        client
1152            .register(OperationIr::NumericFloat(
1153                desc.out.dtype,
1154                NumericOperationIr::MinDim(desc),
1155            ))
1156            .output()
1157    }
1158
1159    fn float_min_dim_with_indices(
1160        tensor: FloatTensor<Self>,
1161        dim: usize,
1162    ) -> (FloatTensor<Self>, IntTensor<Self>) {
1163        let client = tensor.client.clone();
1164        let desc = ReduceDimWithIndicesOpIr::create(
1165            tensor.into_ir(),
1166            dim,
1167            IntElem::<Self>::dtype(),
1168            || client.create_empty_handle(),
1169        );
1170
1171        client
1172            .register(OperationIr::NumericFloat(
1173                desc.tensor.dtype,
1174                NumericOperationIr::MinDimWithIndices(desc),
1175            ))
1176            .outputs()
1177            .into()
1178    }
1179
1180    fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
1181        let client = lhs.client.clone();
1182        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
1183            client.create_empty_handle()
1184        });
1185
1186        client
1187            .register(OperationIr::NumericFloat(
1188                desc.out.dtype,
1189                NumericOperationIr::Powf(desc),
1190            ))
1191            .output()
1192    }
1193
1194    fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
1195        let client = tensor.client.clone();
1196        let desc = PermuteOpIr::create(tensor.into_ir(), axes.into(), || {
1197            client.create_empty_handle()
1198        });
1199
1200        client
1201            .register(OperationIr::BaseFloat(BaseOperationIr::Permute(desc)))
1202            .output()
1203    }
1204
1205    fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
1206        let client = tensor.client.clone();
1207        let desc = ShapeOpIr::expand(tensor.into_ir(), shape, || client.create_empty_handle());
1208
1209        client
1210            .register(OperationIr::BaseFloat(BaseOperationIr::Expand(desc)))
1211            .output()
1212    }
1213
1214    fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
1215        let client = tensor.client.clone();
1216        let desc = FlipOpIr::create(tensor.into_ir(), axes.into(), || {
1217            client.create_empty_handle()
1218        });
1219
1220        client
1221            .register(OperationIr::BaseFloat(BaseOperationIr::Flip(desc)))
1222            .output()
1223    }
1224
1225    fn float_cast(tensor: FloatTensor<Self>, dtype: burn_backend::FloatDType) -> FloatTensor<Self> {
1226        let client = tensor.client.clone();
1227        let desc = CastOpIr::create(tensor.into_ir(), dtype.into(), || {
1228            client.create_empty_handle()
1229        });
1230
1231        client
1232            .register(OperationIr::BaseFloat(BaseOperationIr::Cast(desc)))
1233            .output()
1234    }
1235
1236    fn float_unfold(
1237        tensor: FloatTensor<Self>,
1238        dim: usize,
1239        size: usize,
1240        step: usize,
1241    ) -> FloatTensor<Self> {
1242        let client = tensor.client.clone();
1243        let desc = UnfoldOpIr::create(tensor.into_ir(), dim, size, step, || {
1244            client.create_empty_handle()
1245        });
1246
1247        client
1248            .register(OperationIr::BaseFloat(BaseOperationIr::Unfold(desc)))
1249            .output()
1250    }
1251}