Skip to main content

burn_fusion/ops/
tensor.rs

1use super::NoOp;
2use crate::{
3    Fusion, FusionBackend, binary_float_cmp_ops, binary_float_ops,
4    client::GlobalFusionClient,
5    get_client, reduce_float_ops, reduce_float2int_ops, scalar_float_cmp_ops, scalar_float_ops,
6    stream::{OperationStreams, execution::Operation},
7    unary_float_ops,
8};
9use burn_backend::{
10    BoolDType, Distribution, ExecutionError, FloatDType, IntDType, Scalar, Shape, Slice,
11    TensorData,
12    ops::{FloatTensorOps, GridSampleOptions},
13    tensor::{BoolTensor, Device, FloatTensor, IndexingUpdateOp, IntTensor},
14};
15use burn_ir::*;
16use std::marker::PhantomData;
17
18impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
19    #[cfg_attr(feature = "tracing", tracing::instrument(
20        level="trace",
21        skip(data),
22        fields(?data.shape, ?data.dtype)
23    ))]
24    fn float_from_data(data: TensorData, device: &Device<Self>) -> FloatTensor<Self> {
25        let client = get_client::<B>(device);
26        let dtype = data.dtype;
27        let tensor = B::float_from_data(data, device);
28        let shape = burn_backend::TensorMetadata::shape(&tensor);
29
30        let handle = B::float_tensor_handle(tensor);
31        let desc = InitOperationIr::create(shape, dtype, || client.register_tensor_handle(handle));
32
33        client
34            .register(
35                OperationStreams::default(),
36                OperationIr::Init(desc),
37                NoOp::<B>::new(),
38            )
39            .output()
40    }
41
42    fn float_random(
43        shape: Shape,
44        distribution: Distribution,
45        device: &Device<Self>,
46        dtype: FloatDType,
47    ) -> FloatTensor<Self> {
48        #[derive(new, Debug)]
49        struct RandomOps<B: FusionBackend> {
50            desc: RandomOpIr,
51            device: Device<B>,
52        }
53
54        impl<B: FusionBackend> Operation<B::FusionRuntime> for RandomOps<B> {
55            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
56                let output: B::FloatTensorPrimitive = B::float_random(
57                    self.desc.out.shape.clone(),
58                    self.desc.distribution,
59                    &self.device,
60                    self.desc.out.dtype.into(),
61                );
62                handles.register_float_tensor::<B>(&self.desc.out.id, output);
63            }
64        }
65
66        let dtype = dtype.into();
67        let client = get_client::<B>(device);
68        let desc = RandomOpIr::create(shape, dtype, distribution, || client.create_empty_handle());
69
70        client
71            .register(
72                OperationStreams::default(),
73                OperationIr::Float(dtype, FloatOperationIr::Random(desc.clone())),
74                RandomOps::<B>::new(desc, device.clone()),
75            )
76            .output()
77    }
78
79    fn float_zeros(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
80        #[derive(new, Debug)]
81        struct ZerosOps<B: FusionBackend> {
82            out: TensorIr,
83            device: Device<B>,
84        }
85
86        impl<B: FusionBackend> Operation<B::FusionRuntime> for ZerosOps<B> {
87            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
88                let shape = self.out.shape.clone();
89                let output = B::float_zeros(shape, &self.device, self.out.dtype.into());
90                handles.register_float_tensor::<B>(&self.out.id, output);
91            }
92        }
93
94        let client = get_client::<B>(device);
95        let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
96
97        client
98            .register(
99                OperationStreams::default(),
100                OperationIr::BaseFloat(BaseOperationIr::Zeros(desc.clone())),
101                ZerosOps::<B>::new(desc.out, device.clone()),
102            )
103            .output()
104    }
105
106    fn float_ones(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
107        #[derive(new, Debug)]
108        struct OnesOps<B: FusionBackend> {
109            out: TensorIr,
110            device: Device<B>,
111        }
112
113        impl<B: FusionBackend> Operation<B::FusionRuntime> for OnesOps<B> {
114            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
115                let shape = self.out.shape.clone();
116                let output = B::float_ones(shape, &self.device, self.out.dtype.into());
117                handles.register_float_tensor::<B>(&self.out.id, output);
118            }
119        }
120
121        let client = get_client::<B>(device);
122        let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
123
124        client
125            .register(
126                OperationStreams::default(),
127                OperationIr::BaseFloat(BaseOperationIr::Ones(desc.clone())),
128                OnesOps::<B>::new(desc.out, device.clone()),
129            )
130            .output()
131    }
132
133    fn float_full(
134        shape: Shape,
135        fill_value: Scalar,
136        device: &Device<Self>,
137        dtype: FloatDType,
138    ) -> FloatTensor<Self> {
139        #[derive(new, Debug)]
140        struct FullOps<B: FusionBackend> {
141            out: TensorIr,
142            elem: ScalarIr,
143            device: Device<B>,
144        }
145
146        impl<B: FusionBackend> Operation<B::FusionRuntime> for FullOps<B> {
147            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
148                let shape = self.out.shape.clone();
149                let dtype = self.out.dtype.into();
150                let output: B::FloatTensorPrimitive =
151                    B::float_full(shape, self.elem.into(), &self.device, dtype);
152                handles.register_float_tensor::<B>(&self.out.id, output);
153            }
154        }
155
156        let dtype = dtype.into();
157        let client = get_client::<B>(device);
158        let value = fill_value.into();
159        let desc = FullOpIr::create(shape, dtype, value, || client.create_empty_handle());
160
161        client
162            .register(
163                OperationStreams::default(),
164                OperationIr::NumericFloat(dtype, NumericOperationIr::Full(desc.clone())),
165                FullOps::<B>::new(desc.out, desc.value, device.clone()),
166            )
167            .output()
168    }
169
170    #[cfg_attr(feature = "tracing", tracing::instrument(
171        level="trace",
172        skip(tensor),
173        fields(
174            from = ?tensor.client.device(),
175            shape = ?tensor.shape,
176            dtype = ?tensor.dtype
177        )
178    ))]
179    async fn float_into_data(tensor: FloatTensor<Self>) -> Result<TensorData, ExecutionError> {
180        tensor.into_data::<B>().await
181    }
182
183    fn float_device(tensor: &FloatTensor<Self>) -> Device<Self> {
184        tensor.client.device().clone()
185    }
186
187    #[cfg_attr(feature = "tracing", tracing::instrument(
188        level="trace",
189        skip(tensor),
190        fields(
191            from = ?tensor.client.device(),
192            shape = ?tensor.shape,
193            dtype = ?tensor.dtype,
194        )
195    ))]
196    fn float_to_device(tensor: FloatTensor<Self>, device_dst: &Device<Self>) -> FloatTensor<Self> {
197        let device_src: &B::Device = tensor.client.device();
198
199        if device_src == device_dst {
200            return tensor;
201        }
202
203        let id = tensor.stream;
204        let client_dst = get_client::<B>(device_dst);
205        let client_src = tensor.client.clone();
206
207        GlobalFusionClient::change_client_float::<B>(tensor.into_ir(), client_src, client_dst, id)
208    }
209
210    fn float_into_int(tensor: FloatTensor<Self>, dtype: IntDType) -> IntTensor<Self> {
211        #[derive(new, Debug)]
212        struct IntoIntOps<B: FusionBackend> {
213            desc: CastOpIr,
214            _b: PhantomData<B>,
215        }
216
217        impl<B: FusionBackend> Operation<B::FusionRuntime> for IntoIntOps<B> {
218            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
219                let input = handles.get_float_tensor::<B>(&self.desc.input);
220                let output = B::float_into_int(input, self.desc.out.dtype.into());
221
222                handles.register_int_tensor::<B>(&self.desc.out.id, output);
223            }
224        }
225
226        let streams = OperationStreams::with_inputs([&tensor]);
227
228        let client = tensor.client.clone();
229        let desc = CastOpIr::create(tensor.into_ir(), dtype.into(), || {
230            client.create_empty_handle()
231        });
232
233        client
234            .register(
235                streams,
236                OperationIr::Float(desc.input.dtype, FloatOperationIr::IntoInt(desc.clone())),
237                IntoIntOps::<B>::new(desc),
238            )
239            .output()
240    }
241
242    fn float_empty(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
243        #[derive(new, Debug)]
244        struct EmptyOps<B: FusionBackend> {
245            desc: TensorIr,
246            device: Device<B>,
247        }
248
249        impl<B: FusionBackend> Operation<B::FusionRuntime> for EmptyOps<B> {
250            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
251                let output = B::float_empty(
252                    self.desc.shape.clone(),
253                    &self.device,
254                    self.desc.dtype.into(),
255                );
256                handles.register_float_tensor::<B>(&self.desc.id, output);
257            }
258        }
259
260        let client = get_client::<B>(device);
261        let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
262
263        client
264            .register(
265                OperationStreams::default(),
266                OperationIr::BaseFloat(BaseOperationIr::Empty(desc.clone())),
267                EmptyOps::<B>::new(desc.out, device.clone()),
268            )
269            .output()
270    }
271
272    fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
273        binary_float_ops!(AddOps, B::float_add);
274
275        let streams = OperationStreams::with_inputs([&lhs, &rhs]);
276
277        let client = lhs.client.clone();
278        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
279            client.create_empty_handle()
280        });
281
282        client
283            .register(
284                streams,
285                OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Add(desc.clone())),
286                AddOps::<B>::new(desc),
287            )
288            .output()
289    }
290
291    fn float_add_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
292        scalar_float_ops!(AddOps, B::float_add_scalar);
293
294        let streams = OperationStreams::with_inputs([&lhs]);
295
296        let client = lhs.client.clone();
297        let rhs = rhs.into();
298        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
299
300        client
301            .register(
302                streams,
303                OperationIr::NumericFloat(
304                    desc.out.dtype,
305                    NumericOperationIr::AddScalar(desc.clone()),
306                ),
307                AddOps::<B>::new(desc),
308            )
309            .output()
310    }
311
312    fn float_clamp(tensor: FloatTensor<Self>, min: Scalar, max: Scalar) -> FloatTensor<Self> {
313        #[derive(new, Debug)]
314        struct ClampOps<B: FusionBackend> {
315            desc: ClampOpIr,
316            _b: PhantomData<B>,
317        }
318
319        impl<B: FusionBackend> Operation<B::FusionRuntime> for ClampOps<B> {
320            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
321                let input = handles.get_float_tensor::<B>(&self.desc.tensor);
322                let output = B::float_clamp(input, self.desc.min.into(), self.desc.max.into());
323
324                handles.register_float_tensor::<B>(&self.desc.out.id, output);
325            }
326        }
327
328        let streams = OperationStreams::with_inputs([&tensor]);
329
330        let client = tensor.client.clone();
331        let min = min.into();
332        let max = max.into();
333        let desc = ClampOpIr::create(tensor.into_ir(), min, max, || client.create_empty_handle());
334
335        client
336            .register(
337                streams,
338                OperationIr::NumericFloat(
339                    desc.tensor.dtype,
340                    NumericOperationIr::Clamp(desc.clone()),
341                ),
342                ClampOps::<B>::new(desc),
343            )
344            .output()
345    }
346
347    fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
348        binary_float_ops!(SubOps, B::float_sub);
349
350        let streams = OperationStreams::with_inputs([&lhs, &rhs]);
351
352        let client = lhs.client.clone();
353        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
354            client.create_empty_handle()
355        });
356
357        client
358            .register(
359                streams,
360                OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Sub(desc.clone())),
361                SubOps::<B>::new(desc),
362            )
363            .output()
364    }
365
366    fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
367        scalar_float_ops!(SubOps, B::float_sub_scalar);
368
369        let streams = OperationStreams::with_inputs([&lhs]);
370
371        let client = lhs.client.clone();
372        let rhs = rhs.into();
373        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
374
375        client
376            .register(
377                streams,
378                OperationIr::NumericFloat(
379                    desc.out.dtype,
380                    NumericOperationIr::SubScalar(desc.clone()),
381                ),
382                SubOps::<B>::new(desc),
383            )
384            .output()
385    }
386
387    fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
388        binary_float_ops!(MulOps, B::float_mul);
389
390        let streams = OperationStreams::with_inputs([&lhs, &rhs]);
391
392        let client = lhs.client.clone();
393        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
394            client.create_empty_handle()
395        });
396
397        client
398            .register(
399                streams,
400                OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Mul(desc.clone())),
401                MulOps::<B>::new(desc),
402            )
403            .output()
404    }
405
406    fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
407        scalar_float_ops!(MulOps, B::float_mul_scalar);
408
409        let streams = OperationStreams::with_inputs([&lhs]);
410
411        let client = lhs.client.clone();
412        let rhs = rhs.into();
413        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
414
415        client
416            .register(
417                streams,
418                OperationIr::NumericFloat(
419                    desc.out.dtype,
420                    NumericOperationIr::MulScalar(desc.clone()),
421                ),
422                MulOps::<B>::new(desc),
423            )
424            .output()
425    }
426
427    fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
428        binary_float_ops!(DivOps, B::float_div);
429
430        let streams = OperationStreams::with_inputs([&lhs, &rhs]);
431
432        let client = lhs.client.clone();
433        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
434            client.create_empty_handle()
435        });
436
437        client
438            .register(
439                streams,
440                OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Div(desc.clone())),
441                DivOps::<B>::new(desc),
442            )
443            .output()
444    }
445
446    fn float_div_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
447        scalar_float_ops!(DivOps, B::float_div_scalar);
448
449        let streams = OperationStreams::with_inputs([&lhs]);
450
451        let client = lhs.client.clone();
452        let rhs = rhs.into();
453        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
454
455        client
456            .register(
457                streams,
458                OperationIr::NumericFloat(
459                    desc.out.dtype,
460                    NumericOperationIr::DivScalar(desc.clone()),
461                ),
462                DivOps::<B>::new(desc),
463            )
464            .output()
465    }
466
467    fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
468        binary_float_ops!(ModOps, B::float_remainder);
469
470        let streams = OperationStreams::with_inputs([&lhs, &rhs]);
471
472        let client = lhs.client.clone();
473        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
474            client.create_empty_handle()
475        });
476
477        client
478            .register(
479                streams,
480                OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Rem(desc.clone())),
481                ModOps::<B>::new(desc),
482            )
483            .output()
484    }
485
486    fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
487        scalar_float_ops!(ModOps, B::float_remainder_scalar);
488
489        let streams = OperationStreams::with_inputs([&lhs]);
490
491        let client = lhs.client.clone();
492        let rhs = rhs.into();
493        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
494
495        client
496            .register(
497                streams,
498                OperationIr::NumericFloat(
499                    desc.out.dtype,
500                    NumericOperationIr::RemScalar(desc.clone()),
501                ),
502                ModOps::<B>::new(desc),
503            )
504            .output()
505    }
506
507    fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
508        binary_float_ops!(MatmulOps, B::float_matmul);
509
510        let streams = OperationStreams::with_inputs([&lhs, &rhs]);
511
512        let client = lhs.client.clone();
513        let desc = MatmulOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
514            client.create_empty_handle()
515        });
516
517        client
518            .register(
519                streams,
520                OperationIr::Float(desc.out.dtype, FloatOperationIr::Matmul(desc.clone())),
521                MatmulOps::<B>::new(desc.into()),
522            )
523            .output()
524    }
525
526    fn float_cross(
527        lhs: FloatTensor<Self>,
528        rhs: FloatTensor<Self>,
529        dim: usize,
530    ) -> FloatTensor<Self> {
531        #[derive(new, Debug)]
532        struct CrossOps<B: FusionBackend> {
533            desc: CrossOpIr,
534            _b: PhantomData<B>,
535        }
536
537        impl<B: FusionBackend> Operation<B::FusionRuntime> for CrossOps<B> {
538            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
539                let lhs = handles.get_float_tensor::<B>(&self.desc.lhs);
540                let rhs = handles.get_float_tensor::<B>(&self.desc.rhs);
541                let output = B::float_cross(lhs, rhs, self.desc.dim);
542                handles.register_float_tensor::<B>(&self.desc.out.id, output);
543            }
544        }
545
546        let streams = OperationStreams::with_inputs([&lhs, &rhs]);
547
548        let client = lhs.client.clone();
549        let desc = CrossOpIr::create(lhs.into_ir(), rhs.into_ir(), dim, || {
550            client.create_empty_handle()
551        });
552
553        client
554            .register(
555                streams,
556                OperationIr::Float(desc.out.dtype, FloatOperationIr::Cross(desc.clone())),
557                CrossOps::<B>::new(desc),
558            )
559            .output()
560    }
561
562    fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
563        #[derive(new, Debug)]
564        struct SwapDimsOps<B: FusionBackend> {
565            desc: SwapDimsOpIr,
566            _b: PhantomData<B>,
567        }
568
569        impl<B: FusionBackend> Operation<B::FusionRuntime> for SwapDimsOps<B> {
570            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
571                let input = handles.get_float_tensor::<B>(&self.desc.input);
572                let output = B::float_swap_dims(input, self.desc.dim1, self.desc.dim2);
573                handles.register_float_tensor::<B>(&self.desc.out.id, output);
574            }
575        }
576
577        let streams = OperationStreams::with_inputs([&tensor]);
578
579        let client = tensor.client.clone();
580        let desc = SwapDimsOpIr::create(tensor.into_ir(), dim1, dim2, || {
581            client.create_empty_handle()
582        });
583
584        client
585            .register(
586                streams,
587                OperationIr::BaseFloat(BaseOperationIr::SwapDims(desc.clone())),
588                SwapDimsOps::<B>::new(desc),
589            )
590            .output()
591    }
592
593    fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
594        if tensor.shape == shape {
595            return tensor;
596        }
597
598        #[derive(new, Debug)]
599        struct ReshapeDimsOps<B: FusionBackend> {
600            desc: ShapeOpIr,
601            _b: PhantomData<B>,
602        }
603
604        impl<B: FusionBackend> Operation<B::FusionRuntime> for ReshapeDimsOps<B> {
605            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
606                let input = handles.get_float_tensor::<B>(&self.desc.input);
607                let output = B::float_reshape(input, self.desc.out.shape.clone());
608                handles.register_float_tensor::<B>(&self.desc.out.id, output);
609            }
610        }
611
612        let streams = OperationStreams::with_inputs([&tensor]);
613
614        let client = tensor.client.clone();
615        let desc = ShapeOpIr::reshape(tensor.into_ir(), shape, || client.create_empty_handle());
616
617        client
618            .register(
619                streams,
620                OperationIr::BaseFloat(BaseOperationIr::Reshape(desc.clone())),
621                ReshapeDimsOps::<B>::new(desc),
622            )
623            .output()
624    }
625
626    fn float_gather(
627        dim: usize,
628        tensor: FloatTensor<Self>,
629        indices: IntTensor<Self>,
630    ) -> FloatTensor<Self> {
631        #[derive(new, Debug)]
632        struct GatherOps<B: FusionBackend> {
633            desc: GatherOpIr,
634            _b: PhantomData<B>,
635        }
636
637        impl<B: FusionBackend> Operation<B::FusionRuntime> for GatherOps<B> {
638            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
639                let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
640                let indices = handles.get_int_tensor::<B>(&self.desc.indices);
641
642                let output = B::float_gather(self.desc.dim, tensor, indices);
643                handles.register_float_tensor::<B>(&self.desc.out.id, output);
644            }
645        }
646
647        let streams = OperationStreams::with_inputs([&tensor, &indices]);
648
649        let client = tensor.client.clone();
650        let desc = GatherOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
651            client.create_empty_handle()
652        });
653
654        client
655            .register(
656                streams,
657                OperationIr::BaseFloat(BaseOperationIr::Gather(desc.clone())),
658                GatherOps::<B>::new(desc),
659            )
660            .output()
661    }
662
663    fn float_scatter_add(
664        dim: usize,
665        tensor: FloatTensor<Self>,
666        indices: IntTensor<Self>,
667        value: FloatTensor<Self>,
668    ) -> FloatTensor<Self> {
669        #[derive(new, Debug)]
670        struct ScatterOps<B: FusionBackend> {
671            desc: ScatterOpIr,
672            _b: PhantomData<B>,
673        }
674
675        impl<B: FusionBackend> Operation<B::FusionRuntime> for ScatterOps<B> {
676            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
677                let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
678                let indices = handles.get_int_tensor::<B>(&self.desc.indices);
679                let value = handles.get_float_tensor::<B>(&self.desc.value);
680
681                let output = B::float_scatter_add(self.desc.dim, tensor, indices, value);
682
683                handles.register_float_tensor::<B>(&self.desc.out.id, output);
684            }
685        }
686
687        let streams = OperationStreams::with_inputs([&tensor, &indices, &value]);
688
689        let client = tensor.client.clone();
690        let desc = ScatterOpIr::create(
691            tensor.into_ir(),
692            dim,
693            indices.into_ir(),
694            value.into_ir(),
695            IndexingUpdateOp::Add,
696            || client.create_empty_handle(),
697        );
698
699        client
700            .register(
701                streams,
702                OperationIr::BaseFloat(BaseOperationIr::Scatter(desc.clone())),
703                ScatterOps::<B>::new(desc),
704            )
705            .output()
706    }
707
708    fn float_scatter_nd(
709        data: FloatTensor<Self>,
710        indices: IntTensor<Self>,
711        values: FloatTensor<Self>,
712        reduction: IndexingUpdateOp,
713    ) -> FloatTensor<Self> {
714        #[derive(new, Debug)]
715        struct ScatterNdOps<B: FusionBackend> {
716            desc: ScatterNdOpIr,
717            _b: PhantomData<B>,
718        }
719
720        impl<B: FusionBackend> Operation<B::FusionRuntime> for ScatterNdOps<B> {
721            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
722                let data = handles.get_float_tensor::<B>(&self.desc.data);
723                let indices = handles.get_int_tensor::<B>(&self.desc.indices);
724                let values = handles.get_float_tensor::<B>(&self.desc.values);
725
726                let output = B::float_scatter_nd(data, indices, values, self.desc.reduction);
727
728                handles.register_float_tensor::<B>(&self.desc.out.id, output);
729            }
730        }
731
732        let streams = OperationStreams::with_inputs([&data, &indices, &values]);
733
734        let client = data.client.clone();
735        let desc = ScatterNdOpIr::create(
736            data.into_ir(),
737            indices.into_ir(),
738            values.into_ir(),
739            reduction,
740            || client.create_empty_handle(),
741        );
742
743        client
744            .register(
745                streams,
746                OperationIr::BaseFloat(BaseOperationIr::ScatterNd(desc.clone())),
747                ScatterNdOps::<B>::new(desc),
748            )
749            .output()
750    }
751
752    fn float_gather_nd(data: FloatTensor<Self>, indices: IntTensor<Self>) -> FloatTensor<Self> {
753        #[derive(new, Debug)]
754        struct GatherNdOps<B: FusionBackend> {
755            desc: GatherNdOpIr,
756            _b: PhantomData<B>,
757        }
758
759        impl<B: FusionBackend> Operation<B::FusionRuntime> for GatherNdOps<B> {
760            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
761                let data = handles.get_float_tensor::<B>(&self.desc.data);
762                let indices = handles.get_int_tensor::<B>(&self.desc.indices);
763
764                let output = B::float_gather_nd(data, indices);
765                handles.register_float_tensor::<B>(&self.desc.out.id, output);
766            }
767        }
768
769        let streams = OperationStreams::with_inputs([&data, &indices]);
770
771        let client = data.client.clone();
772        let desc = GatherNdOpIr::create(data.into_ir(), indices.into_ir(), || {
773            client.create_empty_handle()
774        });
775
776        client
777            .register(
778                streams,
779                OperationIr::BaseFloat(BaseOperationIr::GatherNd(desc.clone())),
780                GatherNdOps::<B>::new(desc),
781            )
782            .output()
783    }
784
785    fn float_select(
786        tensor: FloatTensor<Self>,
787        dim: usize,
788        indices: IntTensor<Self>,
789    ) -> FloatTensor<Self> {
790        #[derive(new, Debug)]
791        struct SelectOps<B: FusionBackend> {
792            desc: SelectOpIr,
793            _b: PhantomData<B>,
794        }
795
796        impl<B: FusionBackend> Operation<B::FusionRuntime> for SelectOps<B> {
797            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
798                let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
799                let indices = handles.get_int_tensor::<B>(&self.desc.indices);
800
801                let output = B::float_select(tensor, self.desc.dim, indices);
802
803                handles.register_float_tensor::<B>(&self.desc.out.id, output);
804            }
805        }
806
807        let streams = OperationStreams::with_inputs([&tensor, &indices]);
808
809        let client = tensor.client.clone();
810        let desc = SelectOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
811            client.create_empty_handle()
812        });
813
814        client
815            .register(
816                streams,
817                OperationIr::BaseFloat(BaseOperationIr::Select(desc.clone())),
818                SelectOps::<B>::new(desc),
819            )
820            .output()
821    }
822
823    fn float_select_add(
824        tensor: FloatTensor<Self>,
825        dim: usize,
826        indices: IntTensor<Self>,
827        value: FloatTensor<Self>,
828    ) -> FloatTensor<Self> {
829        #[derive(new, Debug)]
830        struct SelectAssignOps<B: FusionBackend> {
831            desc: SelectAssignOpIr,
832            _b: PhantomData<B>,
833        }
834
835        impl<B: FusionBackend> Operation<B::FusionRuntime> for SelectAssignOps<B> {
836            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
837                let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
838                let indices = handles.get_int_tensor::<B>(&self.desc.indices);
839                let value = handles.get_float_tensor::<B>(&self.desc.value);
840
841                let output = B::float_select_add(tensor, self.desc.dim, indices, value);
842
843                handles.register_float_tensor::<B>(&self.desc.out.id, output);
844            }
845        }
846
847        let streams = OperationStreams::with_inputs([&tensor, &indices, &value]);
848
849        let client = tensor.client.clone();
850        let desc = SelectAssignOpIr::create(
851            tensor.into_ir(),
852            dim,
853            indices.into_ir(),
854            value.into_ir(),
855            IndexingUpdateOp::Add,
856            || client.create_empty_handle(),
857        );
858
859        client
860            .register(
861                streams,
862                OperationIr::BaseFloat(BaseOperationIr::SelectAssign(desc.clone())),
863                SelectAssignOps::<B>::new(desc),
864            )
865            .output()
866    }
867
868    fn float_slice(tensor: FloatTensor<Self>, slices: &[Slice]) -> FloatTensor<Self> {
869        #[derive(new, Debug)]
870        struct SliceOps<B: FusionBackend> {
871            desc: SliceOpIr,
872            _b: PhantomData<B>,
873        }
874
875        impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceOps<B> {
876            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
877                let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
878
879                let output = B::float_slice(tensor, self.desc.ranges.as_slice());
880
881                handles.register_float_tensor::<B>(&self.desc.out.id, output);
882            }
883        }
884
885        let streams = OperationStreams::with_inputs([&tensor]);
886
887        let client = tensor.client.clone();
888        let desc = SliceOpIr::create(tensor.into_ir(), slices.into(), || {
889            client.create_empty_handle()
890        });
891
892        client
893            .register(
894                streams,
895                OperationIr::BaseFloat(BaseOperationIr::Slice(desc.clone())),
896                SliceOps::<B>::new(desc),
897            )
898            .output()
899    }
900
901    fn float_slice_assign(
902        tensor: FloatTensor<Self>,
903        slices: &[burn_backend::Slice],
904        value: FloatTensor<Self>,
905    ) -> FloatTensor<Self> {
906        #[derive(new, Debug)]
907        struct SliceAssignOps<B: FusionBackend> {
908            desc: SliceAssignOpIr,
909            _b: PhantomData<B>,
910        }
911
912        impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceAssignOps<B> {
913            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
914                let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
915                let value = handles.get_float_tensor::<B>(&self.desc.value);
916
917                let output = B::float_slice_assign(tensor, self.desc.ranges.as_slice(), value);
918
919                handles.register_float_tensor::<B>(&self.desc.out.id, output);
920            }
921        }
922
923        let streams = OperationStreams::with_inputs([&tensor, &value]);
924
925        let client = tensor.client.clone();
926        let desc =
927            SliceAssignOpIr::create(tensor.into_ir(), slices.into(), value.into_ir(), || {
928                client.create_empty_handle()
929            });
930
931        client
932            .register(
933                streams,
934                OperationIr::BaseFloat(BaseOperationIr::SliceAssign(desc.clone())),
935                SliceAssignOps::<B>::new(desc),
936            )
937            .output()
938    }
939
940    fn float_mask_where(
941        tensor: FloatTensor<Self>,
942        mask: BoolTensor<Self>,
943        value: FloatTensor<Self>,
944    ) -> FloatTensor<Self> {
945        #[derive(new, Debug)]
946        struct MaskWhereOps<B: FusionBackend> {
947            desc: MaskWhereOpIr,
948            _b: PhantomData<B>,
949        }
950
951        impl<B: FusionBackend> Operation<B::FusionRuntime> for MaskWhereOps<B> {
952            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
953                let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
954                let value = handles.get_float_tensor::<B>(&self.desc.value);
955                let mask = handles.get_bool_tensor::<B>(&self.desc.mask);
956
957                let output = B::float_mask_where(tensor, mask, value);
958
959                handles.register_float_tensor::<B>(&self.desc.out.id, output);
960            }
961        }
962
963        let streams = OperationStreams::with_inputs([&tensor, &mask, &value]);
964
965        let client = tensor.client.clone();
966        let desc = MaskWhereOpIr::create(tensor.into_ir(), mask.into_ir(), value.into_ir(), || {
967            client.create_empty_handle()
968        });
969
970        client
971            .register(
972                streams,
973                OperationIr::BaseFloat(BaseOperationIr::MaskWhere(desc.clone())),
974                MaskWhereOps::<B>::new(desc),
975            )
976            .output()
977    }
978
979    fn float_mask_fill(
980        tensor: FloatTensor<Self>,
981        mask: BoolTensor<Self>,
982        value: Scalar,
983    ) -> FloatTensor<Self> {
984        #[derive(new, Debug)]
985        struct MaskFillOps<B: FusionBackend> {
986            desc: MaskFillOpIr,
987            _b: PhantomData<B>,
988        }
989
990        impl<B: FusionBackend> Operation<B::FusionRuntime> for MaskFillOps<B> {
991            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
992                let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
993                let mask = handles.get_bool_tensor::<B>(&self.desc.mask);
994
995                let output = B::float_mask_fill(tensor, mask, self.desc.value.into());
996
997                handles.register_float_tensor::<B>(&self.desc.out.id, output);
998            }
999        }
1000
1001        let streams = OperationStreams::with_inputs([&tensor, &mask]);
1002
1003        let client = tensor.client.clone();
1004        let value = value.into();
1005        let desc = MaskFillOpIr::create(tensor.into_ir(), mask.into_ir(), value, || {
1006            client.create_empty_handle()
1007        });
1008
1009        client
1010            .register(
1011                streams,
1012                OperationIr::BaseFloat(BaseOperationIr::MaskFill(desc.clone())),
1013                MaskFillOps::<B>::new(desc),
1014            )
1015            .output()
1016    }
1017
1018    fn float_equal(
1019        lhs: FloatTensor<Self>,
1020        rhs: FloatTensor<Self>,
1021        out_dtype: BoolDType,
1022    ) -> BoolTensor<Self> {
1023        binary_float_cmp_ops!(EqualOps, B::float_equal);
1024
1025        let streams = OperationStreams::with_inputs([&lhs, &rhs]);
1026
1027        let client = lhs.client.clone();
1028        let desc =
1029            BinaryOpIr::create_comparison(lhs.into_ir(), rhs.into_ir(), out_dtype.into(), || {
1030                client.create_empty_handle()
1031            });
1032
1033        client
1034            .register(
1035                streams,
1036                OperationIr::BaseFloat(BaseOperationIr::Equal(desc.clone())),
1037                EqualOps::<B>::new(desc),
1038            )
1039            .output()
1040    }
1041
1042    fn float_equal_elem(
1043        lhs: FloatTensor<Self>,
1044        rhs: Scalar,
1045        out_dtype: BoolDType,
1046    ) -> BoolTensor<Self> {
1047        scalar_float_cmp_ops!(EqualElemOps, B::float_equal_elem);
1048
1049        let streams = OperationStreams::with_inputs([&lhs]);
1050
1051        let client = lhs.client.clone();
1052        let rhs = rhs.into();
1053        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, out_dtype.into(), || {
1054            client.create_empty_handle()
1055        });
1056
1057        client
1058            .register(
1059                streams,
1060                OperationIr::BaseFloat(BaseOperationIr::EqualElem(desc.clone())),
1061                EqualElemOps::<B>::new(desc),
1062            )
1063            .output()
1064    }
1065
1066    fn float_greater(
1067        lhs: FloatTensor<Self>,
1068        rhs: FloatTensor<Self>,
1069        out_dtype: BoolDType,
1070    ) -> BoolTensor<Self> {
1071        binary_float_cmp_ops!(GreaterOps, B::float_greater);
1072
1073        let streams = OperationStreams::with_inputs([&lhs, &rhs]);
1074
1075        let client = lhs.client.clone();
1076        let desc =
1077            BinaryOpIr::create_comparison(lhs.into_ir(), rhs.into_ir(), out_dtype.into(), || {
1078                client.create_empty_handle()
1079            });
1080
1081        client
1082            .register(
1083                streams,
1084                OperationIr::NumericFloat(
1085                    desc.lhs.dtype,
1086                    NumericOperationIr::Greater(desc.clone()),
1087                ),
1088                GreaterOps::<B>::new(desc),
1089            )
1090            .output()
1091    }
1092
1093    fn float_greater_elem(
1094        lhs: FloatTensor<Self>,
1095        rhs: Scalar,
1096        out_dtype: BoolDType,
1097    ) -> BoolTensor<Self> {
1098        scalar_float_cmp_ops!(GreaterElemOps, B::float_greater_elem);
1099
1100        let streams = OperationStreams::with_inputs([&lhs]);
1101
1102        let client = lhs.client.clone();
1103        let rhs = rhs.into();
1104        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, out_dtype.into(), || {
1105            client.create_empty_handle()
1106        });
1107
1108        client
1109            .register(
1110                streams,
1111                OperationIr::NumericFloat(
1112                    desc.lhs.dtype,
1113                    NumericOperationIr::GreaterElem(desc.clone()),
1114                ),
1115                GreaterElemOps::<B>::new(desc),
1116            )
1117            .output()
1118    }
1119
1120    fn float_greater_equal(
1121        lhs: FloatTensor<Self>,
1122        rhs: FloatTensor<Self>,
1123        out_dtype: BoolDType,
1124    ) -> BoolTensor<Self> {
1125        binary_float_cmp_ops!(GreaterEqualOps, B::float_greater_equal);
1126
1127        let streams = OperationStreams::with_inputs([&lhs, &rhs]);
1128
1129        let client = lhs.client.clone();
1130        let desc =
1131            BinaryOpIr::create_comparison(lhs.into_ir(), rhs.into_ir(), out_dtype.into(), || {
1132                client.create_empty_handle()
1133            });
1134
1135        client
1136            .register(
1137                streams,
1138                OperationIr::NumericFloat(
1139                    desc.lhs.dtype,
1140                    NumericOperationIr::GreaterEqual(desc.clone()),
1141                ),
1142                GreaterEqualOps::<B>::new(desc),
1143            )
1144            .output()
1145    }
1146
1147    fn float_greater_equal_elem(
1148        lhs: FloatTensor<Self>,
1149        rhs: Scalar,
1150        out_dtype: BoolDType,
1151    ) -> BoolTensor<Self> {
1152        scalar_float_cmp_ops!(GreaterEqualElemOps, B::float_greater_equal_elem);
1153
1154        let streams = OperationStreams::with_inputs([&lhs]);
1155
1156        let client = lhs.client.clone();
1157        let rhs = rhs.into();
1158        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, out_dtype.into(), || {
1159            client.create_empty_handle()
1160        });
1161
1162        client
1163            .register(
1164                streams,
1165                OperationIr::NumericFloat(
1166                    desc.lhs.dtype,
1167                    NumericOperationIr::GreaterEqualElem(desc.clone()),
1168                ),
1169                GreaterEqualElemOps::<B>::new(desc),
1170            )
1171            .output()
1172    }
1173
1174    fn float_lower(
1175        lhs: FloatTensor<Self>,
1176        rhs: FloatTensor<Self>,
1177        out_dtype: BoolDType,
1178    ) -> BoolTensor<Self> {
1179        binary_float_cmp_ops!(LowerOps, B::float_lower);
1180
1181        let streams = OperationStreams::with_inputs([&lhs, &rhs]);
1182
1183        let client = lhs.client.clone();
1184        let desc =
1185            BinaryOpIr::create_comparison(lhs.into_ir(), rhs.into_ir(), out_dtype.into(), || {
1186                client.create_empty_handle()
1187            });
1188
1189        client
1190            .register(
1191                streams,
1192                OperationIr::NumericFloat(desc.lhs.dtype, NumericOperationIr::Lower(desc.clone())),
1193                LowerOps::<B>::new(desc),
1194            )
1195            .output()
1196    }
1197
1198    fn float_lower_elem(
1199        lhs: FloatTensor<Self>,
1200        rhs: Scalar,
1201        out_dtype: BoolDType,
1202    ) -> BoolTensor<Self> {
1203        scalar_float_cmp_ops!(LowerElemOps, B::float_lower_elem);
1204
1205        let streams = OperationStreams::with_inputs([&lhs]);
1206
1207        let client = lhs.client.clone();
1208        let rhs = rhs.into();
1209        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, out_dtype.into(), || {
1210            client.create_empty_handle()
1211        });
1212
1213        client
1214            .register(
1215                streams,
1216                OperationIr::NumericFloat(
1217                    desc.lhs.dtype,
1218                    NumericOperationIr::LowerElem(desc.clone()),
1219                ),
1220                LowerElemOps::<B>::new(desc),
1221            )
1222            .output()
1223    }
1224
1225    fn float_lower_equal(
1226        lhs: FloatTensor<Self>,
1227        rhs: FloatTensor<Self>,
1228        out_dtype: BoolDType,
1229    ) -> BoolTensor<Self> {
1230        binary_float_cmp_ops!(LowerEqualOps, B::float_lower_equal);
1231
1232        let streams = OperationStreams::with_inputs([&lhs, &rhs]);
1233
1234        let client = lhs.client.clone();
1235        let desc =
1236            BinaryOpIr::create_comparison(lhs.into_ir(), rhs.into_ir(), out_dtype.into(), || {
1237                client.create_empty_handle()
1238            });
1239
1240        client
1241            .register(
1242                streams,
1243                OperationIr::NumericFloat(
1244                    desc.lhs.dtype,
1245                    NumericOperationIr::LowerEqual(desc.clone()),
1246                ),
1247                LowerEqualOps::<B>::new(desc),
1248            )
1249            .output()
1250    }
1251
1252    fn float_lower_equal_elem(
1253        lhs: FloatTensor<Self>,
1254        rhs: Scalar,
1255        out_dtype: BoolDType,
1256    ) -> BoolTensor<Self> {
1257        scalar_float_cmp_ops!(LowerEqualElemOps, B::float_lower_equal_elem);
1258
1259        let streams = OperationStreams::with_inputs([&lhs]);
1260
1261        let client = lhs.client.clone();
1262        let rhs = rhs.into();
1263        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, out_dtype.into(), || {
1264            client.create_empty_handle()
1265        });
1266
1267        client
1268            .register(
1269                streams,
1270                OperationIr::NumericFloat(
1271                    desc.lhs.dtype,
1272                    NumericOperationIr::LowerEqualElem(desc.clone()),
1273                ),
1274                LowerEqualElemOps::<B>::new(desc),
1275            )
1276            .output()
1277    }
1278
1279    fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1280        unary_float_ops!(SumOps, B::float_sum, reduce);
1281
1282        let streams = OperationStreams::with_inputs([&tensor]);
1283
1284        let client = tensor.client.clone();
1285        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1286
1287        client
1288            .register(
1289                streams,
1290                OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Sum(desc.clone())),
1291                SumOps::<B>::new(desc.into()),
1292            )
1293            .output()
1294    }
1295
1296    fn float_sum_dim(tensor: FloatTensor<Self>, axis: usize) -> FloatTensor<Self> {
1297        reduce_float_ops!(SumDimOps, |tensor, axis, _| B::float_sum_dim(tensor, axis));
1298
1299        let streams = OperationStreams::with_inputs([&tensor]);
1300
1301        let client = tensor.client.clone();
1302        let desc =
1303            ReduceDimOpIr::create(tensor.into_ir(), axis, 1, || client.create_empty_handle());
1304
1305        client
1306            .register(
1307                streams,
1308                OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::SumDim(desc.clone())),
1309                SumDimOps::<B>::new(desc),
1310            )
1311            .output()
1312    }
1313
1314    fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1315        unary_float_ops!(ProdOps, B::float_prod, reduce);
1316
1317        let streams = OperationStreams::with_inputs([&tensor]);
1318
1319        let client = tensor.client.clone();
1320        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1321
1322        client
1323            .register(
1324                streams,
1325                OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Prod(desc.clone())),
1326                ProdOps::<B>::new(desc.into()),
1327            )
1328            .output()
1329    }
1330
1331    fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
1332        reduce_float_ops!(ProdDimOps, |tensor, axis, _| B::float_prod_dim(
1333            tensor, axis
1334        ));
1335
1336        let streams = OperationStreams::with_inputs([&tensor]);
1337
1338        let client = tensor.client.clone();
1339        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
1340
1341        client
1342            .register(
1343                streams,
1344                OperationIr::NumericFloat(
1345                    desc.out.dtype,
1346                    NumericOperationIr::ProdDim(desc.clone()),
1347                ),
1348                ProdDimOps::<B>::new(desc),
1349            )
1350            .output()
1351    }
1352
1353    fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1354        unary_float_ops!(MeanOps, B::float_mean, reduce);
1355
1356        let streams = OperationStreams::with_inputs([&tensor]);
1357
1358        let client = tensor.client.clone();
1359        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1360
1361        client
1362            .register(
1363                streams,
1364                OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Mean(desc.clone())),
1365                MeanOps::<B>::new(desc.into()),
1366            )
1367            .output()
1368    }
1369
1370    fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
1371        reduce_float_ops!(MeanDimOps, |tensor, axis, _| B::float_mean_dim(
1372            tensor, axis
1373        ));
1374
1375        let streams = OperationStreams::with_inputs([&tensor]);
1376
1377        let client = tensor.client.clone();
1378        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
1379
1380        client
1381            .register(
1382                streams,
1383                OperationIr::NumericFloat(
1384                    desc.out.dtype,
1385                    NumericOperationIr::MeanDim(desc.clone()),
1386                ),
1387                MeanDimOps::<B>::new(desc),
1388            )
1389            .output()
1390    }
1391
1392    fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
1393        #[derive(new, Debug)]
1394        struct CumsumOps<B: FusionBackend> {
1395            desc: DimOpIr,
1396            _b: PhantomData<B>,
1397        }
1398
1399        impl<B: FusionBackend> Operation<B::FusionRuntime> for CumsumOps<B> {
1400            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1401                let input = handles.get_float_tensor::<B>(&self.desc.input);
1402                let output = B::float_cumsum(input, self.desc.axis);
1403                handles.register_float_tensor::<B>(&self.desc.out.id, output);
1404            }
1405        }
1406
1407        let streams = OperationStreams::with_inputs([&tensor]);
1408
1409        let client = tensor.client.clone();
1410        let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
1411
1412        client
1413            .register(
1414                streams,
1415                OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::CumSum(desc.clone())),
1416                CumsumOps::<B>::new(desc),
1417            )
1418            .output()
1419    }
1420
1421    fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
1422        #[derive(new, Debug)]
1423        struct CumprodOps<B: FusionBackend> {
1424            desc: DimOpIr,
1425            _b: PhantomData<B>,
1426        }
1427
1428        impl<B: FusionBackend> Operation<B::FusionRuntime> for CumprodOps<B> {
1429            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1430                let input = handles.get_float_tensor::<B>(&self.desc.input);
1431                let output = B::float_cumprod(input, self.desc.axis);
1432                handles.register_float_tensor::<B>(&self.desc.out.id, output);
1433            }
1434        }
1435
1436        let streams = OperationStreams::with_inputs([&tensor]);
1437
1438        let client = tensor.client.clone();
1439        let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
1440
1441        client
1442            .register(
1443                streams,
1444                OperationIr::NumericFloat(
1445                    desc.out.dtype,
1446                    NumericOperationIr::CumProd(desc.clone()),
1447                ),
1448                CumprodOps::<B>::new(desc),
1449            )
1450            .output()
1451    }
1452
1453    fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
1454        #[derive(new, Debug)]
1455        struct CumminOps<B: FusionBackend> {
1456            desc: DimOpIr,
1457            _b: PhantomData<B>,
1458        }
1459
1460        impl<B: FusionBackend> Operation<B::FusionRuntime> for CumminOps<B> {
1461            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1462                let input = handles.get_float_tensor::<B>(&self.desc.input);
1463                let output = B::float_cummin(input, self.desc.axis);
1464                handles.register_float_tensor::<B>(&self.desc.out.id, output);
1465            }
1466        }
1467
1468        let streams = OperationStreams::with_inputs([&tensor]);
1469
1470        let client = tensor.client.clone();
1471        let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
1472
1473        client
1474            .register(
1475                streams,
1476                OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::CumMin(desc.clone())),
1477                CumminOps::<B>::new(desc),
1478            )
1479            .output()
1480    }
1481
1482    fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
1483        #[derive(new, Debug)]
1484        struct CummaxOps<B: FusionBackend> {
1485            desc: DimOpIr,
1486            _b: PhantomData<B>,
1487        }
1488
1489        impl<B: FusionBackend> Operation<B::FusionRuntime> for CummaxOps<B> {
1490            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1491                let input = handles.get_float_tensor::<B>(&self.desc.input);
1492                let output = B::float_cummax(input, self.desc.axis);
1493                handles.register_float_tensor::<B>(&self.desc.out.id, output);
1494            }
1495        }
1496
1497        let streams = OperationStreams::with_inputs([&tensor]);
1498
1499        let client = tensor.client.clone();
1500        let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
1501
1502        client
1503            .register(
1504                streams,
1505                OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::CumMax(desc.clone())),
1506                CummaxOps::<B>::new(desc),
1507            )
1508            .output()
1509    }
1510
1511    fn float_exp(lhs: FloatTensor<Self>) -> FloatTensor<Self> {
1512        unary_float_ops!(ExpOps, B::float_exp);
1513
1514        let streams = OperationStreams::with_inputs([&lhs]);
1515
1516        let client = lhs.client.clone();
1517        let desc = UnaryOpIr::create(lhs.into_ir(), || client.create_empty_handle());
1518
1519        client
1520            .register(
1521                streams,
1522                OperationIr::Float(desc.out.dtype, FloatOperationIr::Exp(desc.clone())),
1523                ExpOps::<B>::new(desc),
1524            )
1525            .output()
1526    }
1527
1528    fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1529        unary_float_ops!(LogOps, B::float_log);
1530
1531        let streams = OperationStreams::with_inputs([&tensor]);
1532
1533        let client = tensor.client.clone();
1534        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1535
1536        client
1537            .register(
1538                streams,
1539                OperationIr::Float(desc.out.dtype, FloatOperationIr::Log(desc.clone())),
1540                LogOps::<B>::new(desc),
1541            )
1542            .output()
1543    }
1544
1545    fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1546        unary_float_ops!(Log1pOps, B::float_log1p);
1547
1548        let streams = OperationStreams::with_inputs([&tensor]);
1549
1550        let client = tensor.client.clone();
1551        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1552
1553        client
1554            .register(
1555                streams,
1556                OperationIr::Float(desc.out.dtype, FloatOperationIr::Log1p(desc.clone())),
1557                Log1pOps::<B>::new(desc),
1558            )
1559            .output()
1560    }
1561
1562    fn float_powf_scalar_impl(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
1563        scalar_float_ops!(PowfOps, B::float_powf_scalar_impl);
1564
1565        let streams = OperationStreams::with_inputs([&lhs]);
1566
1567        let client = lhs.client.clone();
1568        let rhs = rhs.into();
1569        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
1570
1571        client
1572            .register(
1573                streams,
1574                OperationIr::Float(desc.out.dtype, FloatOperationIr::PowfScalar(desc.clone())),
1575                PowfOps::<B>::new(desc),
1576            )
1577            .output()
1578    }
1579
1580    fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1581        unary_float_ops!(SqrtOps, B::float_sqrt);
1582
1583        let streams = OperationStreams::with_inputs([&tensor]);
1584
1585        let client = tensor.client.clone();
1586        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1587
1588        client
1589            .register(
1590                streams,
1591                OperationIr::Float(desc.out.dtype, FloatOperationIr::Sqrt(desc.clone())),
1592                SqrtOps::<B>::new(desc),
1593            )
1594            .output()
1595    }
1596
1597    fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1598        unary_float_ops!(AbsOps, B::float_abs);
1599
1600        let streams = OperationStreams::with_inputs([&tensor]);
1601
1602        let client = tensor.client.clone();
1603        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1604
1605        client
1606            .register(
1607                streams,
1608                OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Abs(desc.clone())),
1609                AbsOps::<B>::new(desc),
1610            )
1611            .output()
1612    }
1613
1614    fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1615        unary_float_ops!(CosOps, B::float_cos);
1616
1617        let streams = OperationStreams::with_inputs([&tensor]);
1618
1619        let client = tensor.client.clone();
1620        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1621
1622        client
1623            .register(
1624                streams,
1625                OperationIr::Float(desc.out.dtype, FloatOperationIr::Cos(desc.clone())),
1626                CosOps::<B>::new(desc),
1627            )
1628            .output()
1629    }
1630
1631    fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1632        unary_float_ops!(SinOps, B::float_sin);
1633
1634        let streams = OperationStreams::with_inputs([&tensor]);
1635
1636        let client = tensor.client.clone();
1637        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1638
1639        client
1640            .register(
1641                streams,
1642                OperationIr::Float(desc.out.dtype, FloatOperationIr::Sin(desc.clone())),
1643                SinOps::<B>::new(desc),
1644            )
1645            .output()
1646    }
1647
1648    fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1649        unary_float_ops!(TanOps, B::float_tan);
1650
1651        let streams = OperationStreams::with_inputs([&tensor]);
1652
1653        let client = tensor.client.clone();
1654        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1655
1656        client
1657            .register(
1658                streams,
1659                OperationIr::Float(desc.out.dtype, FloatOperationIr::Tan(desc.clone())),
1660                TanOps::<B>::new(desc),
1661            )
1662            .output()
1663    }
1664
1665    fn float_cosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1666        unary_float_ops!(CoshOps, B::float_cosh);
1667
1668        let streams = OperationStreams::with_inputs([&tensor]);
1669
1670        let client = tensor.client.clone();
1671        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1672
1673        client
1674            .register(
1675                streams,
1676                OperationIr::Float(desc.out.dtype, FloatOperationIr::Cosh(desc.clone())),
1677                CoshOps::<B>::new(desc),
1678            )
1679            .output()
1680    }
1681
1682    fn float_sinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1683        unary_float_ops!(SinhOps, B::float_sinh);
1684
1685        let streams = OperationStreams::with_inputs([&tensor]);
1686
1687        let client = tensor.client.clone();
1688        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1689
1690        client
1691            .register(
1692                streams,
1693                OperationIr::Float(desc.out.dtype, FloatOperationIr::Sinh(desc.clone())),
1694                SinhOps::<B>::new(desc),
1695            )
1696            .output()
1697    }
1698
1699    fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1700        unary_float_ops!(TanhOps, B::float_tanh);
1701
1702        let streams = OperationStreams::with_inputs([&tensor]);
1703
1704        let client = tensor.client.clone();
1705        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1706
1707        client
1708            .register(
1709                streams,
1710                OperationIr::Float(desc.out.dtype, FloatOperationIr::Tanh(desc.clone())),
1711                TanhOps::<B>::new(desc),
1712            )
1713            .output()
1714    }
1715
1716    fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1717        unary_float_ops!(ArcCosOps, B::float_acos);
1718
1719        let streams = OperationStreams::with_inputs([&tensor]);
1720
1721        let client = tensor.client.clone();
1722        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1723
1724        client
1725            .register(
1726                streams,
1727                OperationIr::Float(desc.out.dtype, FloatOperationIr::ArcCos(desc.clone())),
1728                ArcCosOps::<B>::new(desc),
1729            )
1730            .output()
1731    }
1732
1733    fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1734        unary_float_ops!(ArcCoshOps, B::float_acosh);
1735
1736        let streams = OperationStreams::with_inputs([&tensor]);
1737
1738        let client = tensor.client.clone();
1739        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1740
1741        client
1742            .register(
1743                streams,
1744                OperationIr::Float(desc.out.dtype, FloatOperationIr::ArcCosh(desc.clone())),
1745                ArcCoshOps::<B>::new(desc),
1746            )
1747            .output()
1748    }
1749
1750    fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1751        unary_float_ops!(ArcSinOps, B::float_asin);
1752
1753        let streams = OperationStreams::with_inputs([&tensor]);
1754
1755        let client = tensor.client.clone();
1756        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1757
1758        client
1759            .register(
1760                streams,
1761                OperationIr::Float(desc.out.dtype, FloatOperationIr::ArcSin(desc.clone())),
1762                ArcSinOps::<B>::new(desc),
1763            )
1764            .output()
1765    }
1766
1767    fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1768        unary_float_ops!(ArcSinhOps, B::float_asinh);
1769
1770        let streams = OperationStreams::with_inputs([&tensor]);
1771
1772        let client = tensor.client.clone();
1773        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1774
1775        client
1776            .register(
1777                streams,
1778                OperationIr::Float(desc.out.dtype, FloatOperationIr::ArcSinh(desc.clone())),
1779                ArcSinhOps::<B>::new(desc),
1780            )
1781            .output()
1782    }
1783
1784    fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1785        unary_float_ops!(ArcTanOps, B::float_atan);
1786
1787        let streams = OperationStreams::with_inputs([&tensor]);
1788
1789        let client = tensor.client.clone();
1790        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1791
1792        client
1793            .register(
1794                streams,
1795                OperationIr::Float(desc.out.dtype, FloatOperationIr::ArcTan(desc.clone())),
1796                ArcTanOps::<B>::new(desc),
1797            )
1798            .output()
1799    }
1800
1801    fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1802        unary_float_ops!(ArcTanhOps, B::float_atanh);
1803
1804        let streams = OperationStreams::with_inputs([&tensor]);
1805
1806        let client = tensor.client.clone();
1807        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1808
1809        client
1810            .register(
1811                streams,
1812                OperationIr::Float(desc.out.dtype, FloatOperationIr::ArcTanh(desc.clone())),
1813                ArcTanhOps::<B>::new(desc),
1814            )
1815            .output()
1816    }
1817
1818    fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
1819        binary_float_ops!(ArcTan2Ops, B::float_atan2);
1820
1821        let streams = OperationStreams::with_inputs([&lhs, &rhs]);
1822
1823        let client = lhs.client.clone();
1824        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
1825            client.create_empty_handle()
1826        });
1827
1828        client
1829            .register(
1830                streams,
1831                OperationIr::Float(desc.out.dtype, FloatOperationIr::ArcTan2(desc.clone())),
1832                ArcTan2Ops::<B>::new(desc),
1833            )
1834            .output()
1835    }
1836
1837    fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1838        unary_float_ops!(Recip, B::float_recip);
1839
1840        let streams = OperationStreams::with_inputs([&tensor]);
1841
1842        let client = tensor.client.clone();
1843        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1844
1845        client
1846            .register(
1847                streams,
1848                OperationIr::Float(desc.out.dtype, FloatOperationIr::Recip(desc.clone())),
1849                Recip::<B>::new(desc),
1850            )
1851            .output()
1852    }
1853
1854    fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1855        unary_float_ops!(TanhOps, B::float_erf);
1856
1857        let streams = OperationStreams::with_inputs([&tensor]);
1858
1859        let client = tensor.client.clone();
1860        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1861
1862        client
1863            .register(
1864                streams,
1865                OperationIr::Float(desc.out.dtype, FloatOperationIr::Erf(desc.clone())),
1866                TanhOps::<B>::new(desc),
1867            )
1868            .output()
1869    }
1870
1871    fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {
1872        #[derive(new, Debug)]
1873        struct CatOps<B: FusionBackend> {
1874            desc: CatOpIr,
1875            _b: PhantomData<B>,
1876        }
1877
1878        impl<B: FusionBackend> Operation<B::FusionRuntime> for CatOps<B> {
1879            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1880                let tensors = self
1881                    .desc
1882                    .tensors
1883                    .iter()
1884                    .map(|tensor| handles.get_float_tensor::<B>(tensor))
1885                    .collect();
1886
1887                let output = B::float_cat(tensors, self.desc.dim);
1888
1889                handles.register_float_tensor::<B>(&self.desc.out.id, output);
1890            }
1891        }
1892
1893        let streams = OperationStreams::with_inputs(&tensors);
1894
1895        let client = tensors.first().unwrap().client.clone();
1896        let tensors = tensors.into_iter().map(|t| t.into_ir()).collect();
1897        let desc = CatOpIr::create(tensors, dim, || client.create_empty_handle());
1898
1899        client
1900            .register(
1901                streams,
1902                OperationIr::BaseFloat(BaseOperationIr::Cat(desc.clone())),
1903                CatOps::<B>::new(desc),
1904            )
1905            .output()
1906    }
1907
1908    fn float_argmax(tensor: FloatTensor<Self>, dim: usize, out_dtype: IntDType) -> IntTensor<Self> {
1909        reduce_float2int_ops!(ArgMaxOps, |input, axis, _, dtype| B::float_argmax(
1910            input, axis, dtype
1911        ));
1912
1913        let streams = OperationStreams::with_inputs([&tensor]);
1914
1915        let client = tensor.client.clone();
1916        let desc = ReduceDimOpIr::create_arg(tensor.into_ir(), dim, 1, out_dtype.into(), || {
1917            client.create_empty_handle()
1918        });
1919
1920        client
1921            .register(
1922                streams,
1923                OperationIr::NumericFloat(
1924                    desc.input.dtype,
1925                    NumericOperationIr::ArgMax(desc.clone()),
1926                ),
1927                ArgMaxOps::<B>::new(desc),
1928            )
1929            .output()
1930    }
1931
1932    fn float_argtopk(
1933        tensor: FloatTensor<Self>,
1934        dim: usize,
1935        k: usize,
1936        out_dtype: IntDType,
1937    ) -> IntTensor<Self> {
1938        reduce_float2int_ops!(ArgTopKOps, B::float_argtopk);
1939
1940        let streams = OperationStreams::with_inputs([&tensor]);
1941
1942        let client = tensor.client.clone();
1943        let desc = ReduceDimOpIr::create_arg(tensor.into_ir(), dim, k, out_dtype.into(), || {
1944            client.create_empty_handle()
1945        });
1946
1947        client
1948            .register(
1949                streams,
1950                OperationIr::NumericFloat(
1951                    desc.input.dtype,
1952                    NumericOperationIr::ArgTopK(desc.clone()),
1953                ),
1954                ArgTopKOps::<B>::new(desc),
1955            )
1956            .output()
1957    }
1958
1959    fn float_topk(tensor: FloatTensor<Self>, dim: usize, k: usize) -> FloatTensor<Self> {
1960        reduce_float_ops!(TopKOps, B::float_topk);
1961
1962        let streams = OperationStreams::with_inputs([&tensor]);
1963
1964        let client = tensor.client.clone();
1965        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, k, || client.create_empty_handle());
1966
1967        client
1968            .register(
1969                streams,
1970                OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::TopK(desc.clone())),
1971                TopKOps::<B>::new(desc),
1972            )
1973            .output()
1974    }
1975
1976    fn float_repeat_dim(tensor: FloatTensor<Self>, dim: usize, times: usize) -> FloatTensor<Self> {
1977        #[derive(new, Debug)]
1978        struct RepeatDimOps<B: FusionBackend> {
1979            desc: RepeatDimOpIr,
1980            _b: PhantomData<B>,
1981        }
1982
1983        impl<B: FusionBackend> Operation<B::FusionRuntime> for RepeatDimOps<B> {
1984            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1985                let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
1986
1987                let output = B::float_repeat_dim(tensor, self.desc.dim, self.desc.times);
1988
1989                handles.register_float_tensor::<B>(&self.desc.out.id, output);
1990            }
1991        }
1992
1993        let streams = OperationStreams::with_inputs([&tensor]);
1994
1995        let client = tensor.client.clone();
1996        let desc = RepeatDimOpIr::create(tensor.into_ir(), dim, times, || {
1997            client.create_empty_handle()
1998        });
1999
2000        client
2001            .register(
2002                streams,
2003                OperationIr::BaseFloat(BaseOperationIr::RepeatDim(desc.clone())),
2004                RepeatDimOps::<B>::new(desc),
2005            )
2006            .output()
2007    }
2008
2009    fn float_argmin(tensor: FloatTensor<Self>, dim: usize, out_dtype: IntDType) -> IntTensor<Self> {
2010        reduce_float2int_ops!(ArgMinOps, |input, axis, _, dtype| B::float_argmin(
2011            input, axis, dtype
2012        ));
2013
2014        let streams = OperationStreams::with_inputs([&tensor]);
2015
2016        let client = tensor.client.clone();
2017        let desc = ReduceDimOpIr::create_arg(tensor.into_ir(), dim, 1, out_dtype.into(), || {
2018            client.create_empty_handle()
2019        });
2020
2021        client
2022            .register(
2023                streams,
2024                OperationIr::NumericFloat(
2025                    desc.input.dtype,
2026                    NumericOperationIr::ArgMin(desc.clone()),
2027                ),
2028                ArgMinOps::<B>::new(desc),
2029            )
2030            .output()
2031    }
2032
2033    fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
2034        unary_float_ops!(MaxOps, B::float_max, reduce);
2035
2036        let streams = OperationStreams::with_inputs([&tensor]);
2037
2038        let client = tensor.client.clone();
2039        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
2040
2041        client
2042            .register(
2043                streams,
2044                OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Max(desc.clone())),
2045                MaxOps::<B>::new(desc.into()),
2046            )
2047            .output()
2048    }
2049
2050    fn float_max_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
2051        reduce_float_ops!(MaxDimOps, |tensor, axis, _| B::float_max_dim(tensor, axis));
2052
2053        let streams = OperationStreams::with_inputs([&tensor]);
2054
2055        let client = tensor.client.clone();
2056        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
2057
2058        client
2059            .register(
2060                streams,
2061                OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::MaxDim(desc.clone())),
2062                MaxDimOps::<B>::new(desc),
2063            )
2064            .output()
2065    }
2066
2067    fn float_max_dim_with_indices(
2068        tensor: FloatTensor<Self>,
2069        dim: usize,
2070        indices_dtype: IntDType,
2071    ) -> (FloatTensor<Self>, IntTensor<Self>) {
2072        #[derive(new, Debug)]
2073        struct MaxDimWithIndicesOps<B: FusionBackend> {
2074            desc: ReduceDimWithIndicesOpIr,
2075            _b: PhantomData<B>,
2076        }
2077
2078        impl<B: FusionBackend> Operation<B::FusionRuntime> for MaxDimWithIndicesOps<B> {
2079            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2080                let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
2081                let (output, indices) = B::float_max_dim_with_indices(
2082                    tensor,
2083                    self.desc.dim,
2084                    self.desc.out_indices.dtype.into(),
2085                );
2086
2087                handles.register_float_tensor::<B>(&self.desc.out.id, output);
2088                handles.register_int_tensor::<B>(&self.desc.out_indices.id, indices);
2089            }
2090        }
2091
2092        let streams = OperationStreams::with_inputs([&tensor]);
2093
2094        let client = tensor.client.clone();
2095        let desc =
2096            ReduceDimWithIndicesOpIr::create(tensor.into_ir(), dim, indices_dtype.into(), || {
2097                client.create_empty_handle()
2098            });
2099
2100        client
2101            .register(
2102                streams,
2103                OperationIr::NumericFloat(
2104                    desc.tensor.dtype,
2105                    NumericOperationIr::MaxDimWithIndices(desc.clone()),
2106                ),
2107                MaxDimWithIndicesOps::<B>::new(desc),
2108            )
2109            .outputs()
2110            .into()
2111    }
2112
2113    fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
2114        unary_float_ops!(MinOps, B::float_min, reduce);
2115
2116        let streams = OperationStreams::with_inputs([&tensor]);
2117
2118        let client = tensor.client.clone();
2119        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
2120
2121        client
2122            .register(
2123                streams,
2124                OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Min(desc.clone())),
2125                MinOps::<B>::new(desc.into()),
2126            )
2127            .output()
2128    }
2129
2130    fn float_min_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
2131        reduce_float_ops!(MinDimOps, |tensor, axis, _| B::float_min_dim(tensor, axis));
2132
2133        let streams = OperationStreams::with_inputs([&tensor]);
2134
2135        let client = tensor.client.clone();
2136        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
2137
2138        client
2139            .register(
2140                streams,
2141                OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::MinDim(desc.clone())),
2142                MinDimOps::<B>::new(desc),
2143            )
2144            .output()
2145    }
2146
2147    fn float_min_dim_with_indices(
2148        tensor: FloatTensor<Self>,
2149        dim: usize,
2150        indices_dtype: IntDType,
2151    ) -> (FloatTensor<Self>, IntTensor<Self>) {
2152        #[derive(new, Debug)]
2153        struct MinDimWithIndicesOps<B: FusionBackend> {
2154            desc: ReduceDimWithIndicesOpIr,
2155            _b: PhantomData<B>,
2156        }
2157
2158        impl<B: FusionBackend> Operation<B::FusionRuntime> for MinDimWithIndicesOps<B> {
2159            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2160                let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
2161                let (output, indices) = B::float_min_dim_with_indices(
2162                    tensor,
2163                    self.desc.dim,
2164                    self.desc.out_indices.dtype.into(),
2165                );
2166
2167                handles.register_float_tensor::<B>(&self.desc.out.id, output);
2168                handles.register_int_tensor::<B>(&self.desc.out_indices.id, indices);
2169            }
2170        }
2171        let streams = OperationStreams::with_inputs([&tensor]);
2172
2173        let client = tensor.client.clone();
2174        let desc =
2175            ReduceDimWithIndicesOpIr::create(tensor.into_ir(), dim, indices_dtype.into(), || {
2176                client.create_empty_handle()
2177            });
2178
2179        client
2180            .register(
2181                streams,
2182                OperationIr::NumericFloat(
2183                    desc.tensor.dtype,
2184                    NumericOperationIr::MinDimWithIndices(desc.clone()),
2185                ),
2186                MinDimWithIndicesOps::<B>::new(desc),
2187            )
2188            .outputs()
2189            .into()
2190    }
2191
2192    fn float_max_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
2193        unary_float_ops!(MaxAbsOps, B::float_max_abs, reduce);
2194
2195        let streams = OperationStreams::with_inputs([&tensor]);
2196
2197        let client = tensor.client.clone();
2198        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
2199
2200        client
2201            .register(
2202                streams,
2203                OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::MaxAbs(desc.clone())),
2204                MaxAbsOps::<B>::new(desc.into()),
2205            )
2206            .output()
2207    }
2208
2209    fn float_max_abs_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
2210        reduce_float_ops!(MaxAbsDimOps, |tensor, axis, _| B::float_max_abs_dim(
2211            tensor, axis
2212        ));
2213
2214        let streams = OperationStreams::with_inputs([&tensor]);
2215
2216        let client = tensor.client.clone();
2217        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
2218
2219        client
2220            .register(
2221                streams,
2222                OperationIr::NumericFloat(
2223                    desc.out.dtype,
2224                    NumericOperationIr::MaxAbsDim(desc.clone()),
2225                ),
2226                MaxAbsDimOps::<B>::new(desc),
2227            )
2228            .output()
2229    }
2230
2231    // TODO: float_powi w/ burn-cubecl-fusion impl
2232    fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
2233        binary_float_ops!(PowOps, B::float_powf);
2234
2235        let streams = OperationStreams::with_inputs([&lhs, &rhs]);
2236
2237        let client = lhs.client.clone();
2238        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
2239            client.create_empty_handle()
2240        });
2241
2242        client
2243            .register(
2244                streams,
2245                OperationIr::Float(desc.out.dtype, FloatOperationIr::Powf(desc.clone())),
2246                PowOps::<B>::new(desc),
2247            )
2248            .output()
2249    }
2250
2251    fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
2252        #[derive(new, Debug)]
2253        struct PermuteDimsOps<B: FusionBackend> {
2254            desc: PermuteOpIr,
2255            _b: PhantomData<B>,
2256        }
2257
2258        impl<B: FusionBackend> Operation<B::FusionRuntime> for PermuteDimsOps<B> {
2259            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2260                let input = handles.get_float_tensor::<B>(&self.desc.input);
2261                let output = B::float_permute(input, self.desc.axes.as_slice());
2262                handles.register_float_tensor::<B>(&self.desc.out.id, output);
2263            }
2264        }
2265
2266        let streams = OperationStreams::with_inputs([&tensor]);
2267
2268        let client = tensor.client.clone();
2269        let desc = PermuteOpIr::create(tensor.into_ir(), axes.into(), || {
2270            client.create_empty_handle()
2271        });
2272
2273        client
2274            .register(
2275                streams,
2276                OperationIr::BaseInt(BaseOperationIr::Permute(desc.clone())),
2277                PermuteDimsOps::<B>::new(desc),
2278            )
2279            .output()
2280    }
2281
2282    fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
2283        #[derive(new, Debug)]
2284        struct ExpandOps<B: FusionBackend> {
2285            desc: ShapeOpIr,
2286            _b: PhantomData<B>,
2287        }
2288
2289        impl<B: FusionBackend> Operation<B::FusionRuntime> for ExpandOps<B> {
2290            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2291                let input = handles.get_float_tensor::<B>(&self.desc.input);
2292                let output = B::float_expand(input, self.desc.out.shape.clone());
2293
2294                handles.register_float_tensor::<B>(&self.desc.out.id, output);
2295            }
2296        }
2297
2298        let streams = OperationStreams::with_inputs([&tensor]);
2299
2300        let client = tensor.client.clone();
2301        let desc = ShapeOpIr::expand(tensor.into_ir(), shape, || client.create_empty_handle());
2302
2303        client
2304            .register(
2305                streams,
2306                OperationIr::BaseFloat(BaseOperationIr::Expand(desc.clone())),
2307                ExpandOps::<B>::new(desc),
2308            )
2309            .output()
2310    }
2311
2312    fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
2313        #[derive(new, Debug)]
2314        struct FlipOps<B: FusionBackend> {
2315            desc: FlipOpIr,
2316            _b: PhantomData<B>,
2317        }
2318
2319        impl<B: FusionBackend> Operation<B::FusionRuntime> for FlipOps<B> {
2320            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2321                let input = handles.get_float_tensor::<B>(&self.desc.input);
2322                let output = B::float_flip(input, &self.desc.axes);
2323                handles.register_float_tensor::<B>(&self.desc.out.id, output);
2324            }
2325        }
2326
2327        let streams = OperationStreams::with_inputs([&tensor]);
2328
2329        let client = tensor.client.clone();
2330        let desc = FlipOpIr::create(tensor.into_ir(), axes.into(), || {
2331            client.create_empty_handle()
2332        });
2333
2334        client
2335            .register(
2336                streams,
2337                OperationIr::BaseInt(BaseOperationIr::Flip(desc.clone())),
2338                FlipOps::<B>::new(desc),
2339            )
2340            .output()
2341    }
2342
2343    fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
2344        unary_float_ops!(RoundOps, B::float_round);
2345
2346        let streams = OperationStreams::with_inputs([&tensor]);
2347
2348        let client = tensor.client.clone();
2349        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
2350
2351        client
2352            .register(
2353                streams,
2354                OperationIr::Float(desc.out.dtype, FloatOperationIr::Round(desc.clone())),
2355                RoundOps::<B>::new(desc),
2356            )
2357            .output()
2358    }
2359
2360    fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
2361        unary_float_ops!(FloorOps, B::float_floor);
2362
2363        let streams = OperationStreams::with_inputs([&tensor]);
2364
2365        let client = tensor.client.clone();
2366        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
2367
2368        client
2369            .register(
2370                streams,
2371                OperationIr::Float(desc.out.dtype, FloatOperationIr::Floor(desc.clone())),
2372                FloorOps::<B>::new(desc),
2373            )
2374            .output()
2375    }
2376
2377    fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
2378        unary_float_ops!(CeilOps, B::float_ceil);
2379
2380        let streams = OperationStreams::with_inputs([&tensor]);
2381
2382        let client = tensor.client.clone();
2383        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
2384
2385        client
2386            .register(
2387                streams,
2388                OperationIr::Float(desc.out.dtype, FloatOperationIr::Ceil(desc.clone())),
2389                CeilOps::<B>::new(desc),
2390            )
2391            .output()
2392    }
2393
2394    fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
2395        unary_float_ops!(TruncOps, B::float_trunc);
2396
2397        let streams = OperationStreams::with_inputs([&tensor]);
2398
2399        let client = tensor.client.clone();
2400        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
2401
2402        client
2403            .register(
2404                streams,
2405                OperationIr::Float(desc.out.dtype, FloatOperationIr::Trunc(desc.clone())),
2406                TruncOps::<B>::new(desc),
2407            )
2408            .output()
2409    }
2410
2411    fn float_cast(tensor: FloatTensor<Self>, dtype: burn_backend::FloatDType) -> FloatTensor<Self> {
2412        #[derive(new, Debug)]
2413        struct CastOps<B: FusionBackend> {
2414            desc: CastOpIr,
2415            dtype: burn_backend::FloatDType,
2416            _b: PhantomData<B>,
2417        }
2418
2419        impl<B: FusionBackend> Operation<B::FusionRuntime> for CastOps<B> {
2420            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2421                let input = handles.get_float_tensor::<B>(&self.desc.input);
2422                let output: B::FloatTensorPrimitive = B::float_cast(input, self.dtype);
2423                handles.register_float_tensor::<B>(&self.desc.out.id, output);
2424            }
2425        }
2426
2427        let streams = OperationStreams::with_inputs([&tensor]);
2428
2429        let client = tensor.client.clone();
2430        let desc = CastOpIr::create(tensor.into_ir(), dtype.into(), || {
2431            client.create_empty_handle()
2432        });
2433
2434        client
2435            .register(
2436                streams,
2437                OperationIr::BaseFloat(BaseOperationIr::Cast(desc.clone())),
2438                CastOps::<B>::new(desc, dtype),
2439            )
2440            .output()
2441    }
2442
2443    fn float_unfold(
2444        tensor: FloatTensor<Self>,
2445        dim: usize,
2446        size: usize,
2447        step: usize,
2448    ) -> FloatTensor<Self> {
2449        #[derive(new, Debug)]
2450        struct UnfoldOps<B: FusionBackend> {
2451            desc: UnfoldOpIr,
2452            _b: PhantomData<B>,
2453        }
2454
2455        impl<B: FusionBackend> Operation<B::FusionRuntime> for UnfoldOps<B> {
2456            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2457                let input = handles.get_float_tensor::<B>(&self.desc.input);
2458                let output = B::float_unfold(input, self.desc.dim, self.desc.size, self.desc.step);
2459
2460                handles.register_float_tensor::<B>(&self.desc.out.id, output);
2461            }
2462        }
2463
2464        let streams = OperationStreams::with_inputs([&tensor]);
2465
2466        let client = tensor.client.clone();
2467        let desc = UnfoldOpIr::create(tensor.into_ir(), dim, size, step, || {
2468            client.create_empty_handle()
2469        });
2470
2471        client
2472            .register(
2473                streams,
2474                OperationIr::BaseFloat(BaseOperationIr::Unfold(desc.clone())),
2475                UnfoldOps::<B>::new(desc),
2476            )
2477            .output()
2478    }
2479
2480    fn float_is_nan(tensor: FloatTensor<Self>, out_dtype: BoolDType) -> BoolTensor<Self> {
2481        #[derive(new, Debug)]
2482        struct IsNanOps<B: FusionBackend> {
2483            desc: UnaryOpIr,
2484            _b: PhantomData<B>,
2485        }
2486        impl<B: FusionBackend> Operation<B::FusionRuntime> for IsNanOps<B> {
2487            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2488                let input = handles.get_float_tensor::<B>(&self.desc.input);
2489                let output = B::float_is_nan(input, self.desc.out.dtype.into());
2490                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
2491            }
2492        }
2493
2494        let streams = OperationStreams::with_inputs([&tensor]);
2495
2496        let client = tensor.client.clone();
2497        let desc = UnaryOpIr::create_comparison(tensor.into_ir(), out_dtype.into(), || {
2498            client.create_empty_handle()
2499        });
2500
2501        client
2502            .register(
2503                streams,
2504                OperationIr::Float(desc.input.dtype, FloatOperationIr::IsNan(desc.clone())),
2505                IsNanOps::<B>::new(desc),
2506            )
2507            .output()
2508    }
2509
2510    fn float_is_inf(tensor: FloatTensor<Self>, out_dtype: BoolDType) -> BoolTensor<Self> {
2511        #[derive(new, Debug)]
2512        struct IsInfOps<B: FusionBackend> {
2513            desc: UnaryOpIr,
2514            _b: PhantomData<B>,
2515        }
2516        impl<B: FusionBackend> Operation<B::FusionRuntime> for IsInfOps<B> {
2517            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2518                let input = handles.get_float_tensor::<B>(&self.desc.input);
2519                let output = B::float_is_inf(input, self.desc.out.dtype.into());
2520                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
2521            }
2522        }
2523
2524        let streams = OperationStreams::with_inputs([&tensor]);
2525
2526        let client = tensor.client.clone();
2527        let desc = UnaryOpIr::create_comparison(tensor.into_ir(), out_dtype.into(), || {
2528            client.create_empty_handle()
2529        });
2530
2531        client
2532            .register(
2533                streams,
2534                OperationIr::Float(desc.input.dtype, FloatOperationIr::IsInf(desc.clone())),
2535                IsInfOps::<B>::new(desc),
2536            )
2537            .output()
2538    }
2539
2540    fn float_grid_sample_2d(
2541        tensor: FloatTensor<Self>,
2542        grid: FloatTensor<Self>,
2543        options: GridSampleOptions,
2544    ) -> FloatTensor<Self> {
2545        #[derive(new, Debug)]
2546        struct GridSample2dOps<B: FusionBackend> {
2547            desc: GridSample2dOpIr,
2548            _b: PhantomData<B>,
2549        }
2550
2551        impl<B: FusionBackend> Operation<B::FusionRuntime> for GridSample2dOps<B> {
2552            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2553                let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
2554                let grid = handles.get_float_tensor::<B>(&self.desc.grid);
2555                let output =
2556                    B::float_grid_sample_2d(tensor, grid, self.desc.options.clone().into());
2557                handles.register_float_tensor::<B>(&self.desc.out.id, output);
2558            }
2559        }
2560
2561        let streams = OperationStreams::with_inputs([&tensor, &grid]);
2562
2563        let client = tensor.client.clone();
2564        let desc =
2565            GridSample2dOpIr::create(tensor.into_ir(), grid.into_ir(), options.into(), || {
2566                client.create_empty_handle()
2567            });
2568
2569        client
2570            .register(
2571                streams,
2572                OperationIr::Float(desc.out.dtype, FloatOperationIr::GridSample2d(desc.clone())),
2573                GridSample2dOps::<B>::new(desc),
2574            )
2575            .output()
2576    }
2577}