burn_fusion/ops/
float.rs

1use super::NoOp;
2use crate::{
3    Fusion, FusionBackend, binary_float_cmp_ops, binary_float_ops, get_client,
4    ops::binary::check_binary_op_types,
5    reduce_float_ops, reduce_float2int_ops, scalar_float_cmp_ops, scalar_float_ops,
6    stream::{OperationStreams, StreamId, execution::Operation},
7    unary_float_ops,
8};
9use burn_ir::*;
10use burn_tensor::ops::unfold::calculate_unfold_shape;
11use burn_tensor::{
12    Device, Distribution, Element, FloatDType, Shape, Slice, TensorData, TensorMetadata,
13    ops::{BoolTensor, FloatElem, FloatTensor, FloatTensorOps, IntTensor},
14};
15use std::marker::PhantomData;
16
17impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
18    fn float_from_data(data: TensorData, device: &Device<Self>) -> FloatTensor<Self> {
19        let stream = StreamId::current();
20        let client = get_client::<B>(&device.clone());
21        let dtype = data.dtype;
22        let tensor = B::float_from_data(data, device);
23        let shape = tensor.shape();
24
25        let handle = B::float_tensor_handle(tensor);
26        let out = client.register_tensor(handle, shape, stream, dtype);
27        let desc = out.to_ir_out();
28
29        client.register(
30            OperationStreams::default(),
31            OperationIr::Init(InitOperationIr { out: desc }),
32            NoOp::<B>::new(),
33        );
34
35        out
36    }
37
38    fn float_random(
39        shape: Shape,
40        distribution: Distribution,
41        device: &Device<Self>,
42    ) -> FloatTensor<Self> {
43        #[derive(new, Debug)]
44        struct RandomOps<B: FusionBackend> {
45            desc: RandomOpIr,
46            device: Device<B>,
47        }
48
49        impl<B: FusionBackend> Operation<B::FusionRuntime> for RandomOps<B> {
50            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
51                let output: B::FloatTensorPrimitive = B::float_random(
52                    self.desc.out.shape.clone(),
53                    self.desc.distribution,
54                    &self.device,
55                );
56                handles.register_float_tensor::<B>(&self.desc.out.id, output);
57            }
58        }
59
60        let client = get_client::<B>(&device.clone());
61        let out = client.tensor_uninitialized(shape, B::FloatElem::dtype());
62
63        let desc = RandomOpIr {
64            out: out.to_ir_out(),
65            distribution,
66        };
67        client.register(
68            OperationStreams::default(),
69            OperationIr::Float(
70                FloatElem::<Self>::dtype(),
71                FloatOperationIr::Random(desc.clone()),
72            ),
73            RandomOps::<B>::new(desc, device.clone()),
74        );
75
76        out
77    }
78
79    fn float_zeros(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
80        #[derive(new, Debug)]
81        struct ZerosOps<B: FusionBackend> {
82            out: TensorIr,
83            device: Device<B>,
84        }
85
86        impl<B: FusionBackend> Operation<B::FusionRuntime> for ZerosOps<B> {
87            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
88                let shape = self.out.shape.clone();
89                let output = B::float_zeros(shape, &self.device, self.out.dtype.into());
90                handles.register_float_tensor::<B>(&self.out.id, output);
91            }
92        }
93
94        let dtype = dtype.into();
95        let client = get_client::<B>(&device.clone());
96        let out = client.tensor_uninitialized(shape, dtype);
97
98        let desc = out.to_ir_out();
99        client.register(
100            OperationStreams::default(),
101            OperationIr::NumericFloat(dtype, NumericOperationIr::Zeros(desc.clone())),
102            ZerosOps::<B>::new(desc, device.clone()),
103        );
104
105        out
106    }
107
108    fn float_ones(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
109        #[derive(new, Debug)]
110        struct OnesOps<B: FusionBackend> {
111            out: TensorIr,
112            device: Device<B>,
113        }
114
115        impl<B: FusionBackend> Operation<B::FusionRuntime> for OnesOps<B> {
116            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
117                let shape = self.out.shape.clone();
118                let output = B::float_ones(shape, &self.device, self.out.dtype.into());
119                handles.register_float_tensor::<B>(&self.out.id, output);
120            }
121        }
122
123        let dtype = dtype.into();
124        let client = get_client::<B>(&device.clone());
125        let out = client.tensor_uninitialized(shape, dtype);
126
127        let desc = out.to_ir_out();
128        client.register(
129            OperationStreams::default(),
130            OperationIr::NumericFloat(dtype, NumericOperationIr::Ones(desc.clone())),
131            OnesOps::<B>::new(desc, device.clone()),
132        );
133
134        out
135    }
136
137    fn float_full(
138        shape: Shape,
139        fill_value: FloatElem<Self>,
140        device: &Device<Self>,
141        dtype: FloatDType,
142    ) -> FloatTensor<Self> {
143        #[derive(new, Debug)]
144        struct FullOps<B: FusionBackend> {
145            out: TensorIr,
146            elem: ScalarIr,
147            device: Device<B>,
148        }
149
150        impl<B: FusionBackend> Operation<B::FusionRuntime> for FullOps<B> {
151            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
152                let shape = self.out.shape.clone();
153                let output: B::FloatTensorPrimitive =
154                    B::float_full(shape, self.elem.elem(), &self.device, self.out.dtype.into());
155                handles.register_float_tensor::<B>(&self.out.id, output);
156            }
157        }
158
159        let dtype = dtype.into();
160        let client = get_client::<B>(&device.clone());
161        let out = client.tensor_uninitialized(shape, dtype);
162
163        let desc = (out.to_ir_out(), ScalarIr::with_dtype(fill_value, &dtype));
164        client.register(
165            OperationStreams::default(),
166            OperationIr::NumericFloat(dtype, NumericOperationIr::Full(desc.clone())),
167            FullOps::<B>::new(desc.0, desc.1, device.clone()),
168        );
169
170        out
171    }
172
173    async fn float_into_data(tensor: FloatTensor<Self>) -> TensorData {
174        tensor.into_data::<B>().await
175    }
176
177    fn float_device(tensor: &FloatTensor<Self>) -> Device<Self> {
178        tensor.client.device().clone()
179    }
180
181    fn float_to_device(tensor: FloatTensor<Self>, device: &Device<Self>) -> FloatTensor<Self> {
182        let device_original: &B::Device = tensor.client.device();
183        let device_target: B::Device = device.clone();
184
185        if device_original == &device_target {
186            return tensor;
187        }
188
189        let id = tensor.stream;
190        let client_target = get_client::<B>(&device_target);
191        let client_original = tensor.client.clone();
192
193        client_original
194            .clone()
195            .change_client_float::<B>(tensor.into_ir(), client_target, id)
196    }
197
198    fn float_into_int(tensor: FloatTensor<Self>) -> IntTensor<Self> {
199        #[derive(new, Debug)]
200        struct IntoIntOps<B: FusionBackend> {
201            desc: UnaryOpIr,
202            _b: PhantomData<B>,
203        }
204
205        impl<B: FusionBackend> Operation<B::FusionRuntime> for IntoIntOps<B> {
206            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
207                let input = handles.get_float_tensor::<B>(&self.desc.input);
208                let output = B::float_into_int(input);
209
210                handles.register_int_tensor::<B>(&self.desc.out.id, output);
211            }
212        }
213
214        let mut streams = OperationStreams::default();
215        streams.tensor(&tensor);
216        let dtype = tensor.dtype;
217        let out = tensor
218            .client
219            .tensor_uninitialized(tensor.shape.clone(), B::IntElem::dtype());
220
221        let desc = UnaryOpIr {
222            input: tensor.into_ir(),
223            out: out.to_ir_out(),
224        };
225        out.client.register(
226            streams,
227            OperationIr::Float(dtype, FloatOperationIr::IntoInt(desc.clone())),
228            IntoIntOps::<B>::new(desc),
229        );
230
231        out
232    }
233
234    fn float_empty(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
235        #[derive(new, Debug)]
236        struct EmptyOps<B: FusionBackend> {
237            desc: TensorIr,
238            device: Device<B>,
239        }
240
241        impl<B: FusionBackend> Operation<B::FusionRuntime> for EmptyOps<B> {
242            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
243                let output = B::float_empty(
244                    self.desc.shape.clone(),
245                    &self.device,
246                    self.desc.dtype.into(),
247                );
248                handles.register_float_tensor::<B>(&self.desc.id, output);
249            }
250        }
251
252        let client = get_client::<B>(&device.clone());
253        let out = client.tensor_uninitialized(shape.clone(), dtype.into());
254
255        let desc = out.to_ir_out();
256
257        client.register(
258            OperationStreams::default(),
259            OperationIr::BaseFloat(BaseOperationIr::Empty(desc.clone())),
260            EmptyOps::<B>::new(desc, device.clone()),
261        );
262
263        out
264    }
265
266    fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
267        binary_float_ops!(AddOps, B::float_add);
268
269        let mut streams = OperationStreams::default();
270        streams.tensor(&lhs);
271        streams.tensor(&rhs);
272        let dtype = lhs.dtype;
273        let out = lhs
274            .client
275            .tensor_uninitialized(lhs.shape.broadcast(&rhs.shape).unwrap(), lhs.dtype);
276
277        let desc = BinaryOpIr {
278            lhs: lhs.into_ir(),
279            rhs: rhs.into_ir(),
280            out: out.to_ir_out(),
281        };
282
283        out.client.register(
284            streams,
285            OperationIr::NumericFloat(dtype, NumericOperationIr::Add(desc.clone())),
286            AddOps::<B>::new(desc),
287        );
288
289        out
290    }
291
292    fn float_add_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
293        scalar_float_ops!(AddOps, B::float_add_scalar);
294
295        let mut streams = OperationStreams::default();
296        streams.tensor(&lhs);
297        let dtype = lhs.dtype;
298        let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype);
299
300        let desc = ScalarOpIr {
301            lhs: lhs.into_ir(),
302            rhs: ScalarIr::with_dtype(rhs, &dtype),
303            out: out.to_ir_out(),
304        };
305        out.client.register(
306            streams,
307            OperationIr::NumericFloat(dtype, NumericOperationIr::AddScalar(desc.clone())),
308            AddOps::<B>::new(desc),
309        );
310
311        out
312    }
313
314    fn float_clamp(
315        tensor: FloatTensor<Self>,
316        min: FloatElem<Self>,
317        max: FloatElem<Self>,
318    ) -> FloatTensor<Self> {
319        #[derive(new, Debug)]
320        struct ClampOps<B: FusionBackend> {
321            desc: ClampOpIr,
322            _b: PhantomData<B>,
323        }
324
325        impl<B: FusionBackend> Operation<B::FusionRuntime> for ClampOps<B> {
326            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
327                let input = handles.get_float_tensor::<B>(&self.desc.tensor);
328                let output = B::float_clamp(input, self.desc.min.elem(), self.desc.max.elem());
329
330                handles.register_float_tensor::<B>(&self.desc.out.id, output);
331            }
332        }
333
334        let mut streams = OperationStreams::default();
335        streams.tensor(&tensor);
336        let dtype = tensor.dtype;
337        let out = tensor
338            .client
339            .tensor_uninitialized(tensor.shape.clone(), dtype);
340
341        let desc = ClampOpIr {
342            tensor: tensor.into_ir(),
343            min: ScalarIr::with_dtype(min, &dtype),
344            max: ScalarIr::with_dtype(max, &dtype),
345            out: out.to_ir_out(),
346        };
347        out.client.register(
348            streams,
349            OperationIr::NumericFloat(dtype, NumericOperationIr::Clamp(desc.clone())),
350            ClampOps::<B>::new(desc),
351        );
352
353        out
354    }
355
356    fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
357        binary_float_ops!(SubOps, B::float_sub);
358
359        let mut streams = OperationStreams::default();
360        streams.tensor(&lhs);
361        streams.tensor(&rhs);
362        let dtype = lhs.dtype;
363        let out = lhs
364            .client
365            .tensor_uninitialized(lhs.shape.broadcast(&rhs.shape).unwrap(), lhs.dtype);
366
367        let desc = BinaryOpIr {
368            lhs: lhs.into_ir(),
369            rhs: rhs.into_ir(),
370            out: out.to_ir_out(),
371        };
372        out.client.register(
373            streams,
374            OperationIr::NumericFloat(dtype, NumericOperationIr::Sub(desc.clone())),
375            SubOps::<B>::new(desc),
376        );
377
378        out
379    }
380
381    fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
382        scalar_float_ops!(SubOps, B::float_sub_scalar);
383
384        let mut streams = OperationStreams::default();
385        streams.tensor(&lhs);
386        let dtype = lhs.dtype;
387        let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype);
388        let desc = ScalarOpIr {
389            lhs: lhs.into_ir(),
390            rhs: ScalarIr::with_dtype(rhs, &dtype),
391            out: out.to_ir_out(),
392        };
393
394        out.client.register(
395            streams,
396            OperationIr::NumericFloat(dtype, NumericOperationIr::SubScalar(desc.clone())),
397            SubOps::<B>::new(desc),
398        );
399
400        out
401    }
402
403    fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
404        binary_float_ops!(MulOps, B::float_mul);
405
406        let mut streams = OperationStreams::default();
407        streams.tensor(&lhs);
408        streams.tensor(&rhs);
409        let dtype = lhs.dtype;
410        let out = lhs
411            .client
412            .tensor_uninitialized(lhs.shape.broadcast(&rhs.shape).unwrap(), lhs.dtype);
413
414        let desc = BinaryOpIr {
415            lhs: lhs.into_ir(),
416            rhs: rhs.into_ir(),
417            out: out.to_ir_out(),
418        };
419        out.client.register(
420            streams,
421            OperationIr::NumericFloat(dtype, NumericOperationIr::Mul(desc.clone())),
422            MulOps::<B>::new(desc),
423        );
424
425        out
426    }
427
428    fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
429        scalar_float_ops!(MulOps, B::float_mul_scalar);
430
431        let mut streams = OperationStreams::default();
432        streams.tensor(&lhs);
433        let dtype = lhs.dtype;
434        let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype);
435
436        let desc = ScalarOpIr {
437            lhs: lhs.into_ir(),
438            rhs: ScalarIr::with_dtype(rhs, &dtype),
439            out: out.to_ir_out(),
440        };
441        out.client.register(
442            streams,
443            OperationIr::NumericFloat(dtype, NumericOperationIr::MulScalar(desc.clone())),
444            MulOps::<B>::new(desc),
445        );
446
447        out
448    }
449
450    fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
451        binary_float_ops!(DivOps, B::float_div);
452
453        let mut streams = OperationStreams::default();
454        streams.tensor(&lhs);
455        streams.tensor(&rhs);
456        let dtype = lhs.dtype;
457        let out = lhs
458            .client
459            .tensor_uninitialized(lhs.shape.broadcast(&rhs.shape).unwrap(), lhs.dtype);
460
461        let desc = BinaryOpIr {
462            lhs: lhs.into_ir(),
463            rhs: rhs.into_ir(),
464            out: out.to_ir_out(),
465        };
466        out.client.register(
467            streams,
468            OperationIr::NumericFloat(dtype, NumericOperationIr::Div(desc.clone())),
469            DivOps::<B>::new(desc),
470        );
471
472        out
473    }
474
475    fn float_div_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
476        scalar_float_ops!(DivOps, B::float_div_scalar);
477
478        let mut streams = OperationStreams::default();
479        streams.tensor(&lhs);
480        let dtype = lhs.dtype;
481        let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype);
482
483        let desc = ScalarOpIr {
484            lhs: lhs.into_ir(),
485            rhs: ScalarIr::with_dtype(rhs, &dtype),
486            out: out.to_ir_out(),
487        };
488        out.client.register(
489            streams,
490            OperationIr::NumericFloat(dtype, NumericOperationIr::DivScalar(desc.clone())),
491            DivOps::<B>::new(desc),
492        );
493
494        out
495    }
496
497    fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
498        binary_float_ops!(ModOps, B::float_remainder);
499
500        let mut streams = OperationStreams::default();
501        streams.tensor(&lhs);
502        streams.tensor(&rhs);
503        let dtype = lhs.dtype;
504        let out = lhs
505            .client
506            .tensor_uninitialized(lhs.shape.broadcast(&rhs.shape).unwrap(), lhs.dtype);
507
508        let desc = BinaryOpIr {
509            lhs: lhs.into_ir(),
510            rhs: rhs.into_ir(),
511            out: out.to_ir_out(),
512        };
513        out.client.register(
514            streams,
515            OperationIr::NumericFloat(dtype, NumericOperationIr::Rem(desc.clone())),
516            ModOps::<B>::new(desc),
517        );
518
519        out
520    }
521
522    fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
523        scalar_float_ops!(ModOps, B::float_remainder_scalar);
524
525        let mut streams = OperationStreams::default();
526        streams.tensor(&lhs);
527        let dtype = lhs.dtype;
528        let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype);
529
530        let desc = ScalarOpIr {
531            lhs: lhs.into_ir(),
532            rhs: ScalarIr::with_dtype(rhs, &dtype),
533            out: out.to_ir_out(),
534        };
535        out.client.register(
536            streams,
537            OperationIr::NumericFloat(dtype, NumericOperationIr::RemScalar(desc.clone())),
538            ModOps::<B>::new(desc),
539        );
540
541        out
542    }
543
544    fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
545        binary_float_ops!(MatmulOps, B::float_matmul);
546
547        let mut streams = OperationStreams::default();
548        streams.tensor(&lhs);
549        streams.tensor(&rhs);
550        let dtype = lhs.dtype;
551        let shape = Shape::matmul(&lhs.shape, &rhs.shape).unwrap();
552
553        let out = lhs.client.tensor_uninitialized(shape, dtype);
554        let desc = BinaryOpIr {
555            lhs: lhs.into_ir(),
556            rhs: rhs.into_ir(),
557            out: out.to_ir_out(),
558        };
559
560        out.client.register(
561            streams,
562            OperationIr::Float(dtype, FloatOperationIr::Matmul(desc.clone())),
563            MatmulOps::<B>::new(desc),
564        );
565
566        out
567    }
568
569    fn float_cross(
570        lhs: FloatTensor<Self>,
571        rhs: FloatTensor<Self>,
572        dim: usize,
573    ) -> FloatTensor<Self> {
574        #[derive(new, Debug)]
575        struct CrossOps<B: FusionBackend> {
576            desc: CrossOpIr,
577            _b: PhantomData<B>,
578        }
579
580        impl<B: FusionBackend> Operation<B::FusionRuntime> for CrossOps<B> {
581            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
582                let lhs = handles.get_float_tensor::<B>(&self.desc.lhs);
583                let rhs = handles.get_float_tensor::<B>(&self.desc.rhs);
584                let output = B::float_cross(lhs, rhs, self.desc.dim);
585                handles.register_float_tensor::<B>(&self.desc.out.id, output);
586            }
587        }
588
589        let mut streams = OperationStreams::default();
590        streams.tensor(&lhs);
591        streams.tensor(&rhs);
592        let dtype = lhs.dtype;
593        let out = lhs
594            .client
595            .tensor_uninitialized(lhs.shape.broadcast(&rhs.shape).unwrap(), dtype);
596        let desc = CrossOpIr {
597            lhs: lhs.into_ir(),
598            rhs: rhs.into_ir(),
599            out: out.to_ir_out(),
600            dim,
601        };
602
603        out.client.register(
604            streams,
605            OperationIr::Float(dtype, FloatOperationIr::Cross(desc.clone())),
606            CrossOps::<B>::new(desc),
607        );
608
609        out
610    }
611
612    fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
613        #[derive(new, Debug)]
614        struct SwapDimsOps<B: FusionBackend> {
615            desc: SwapDimsOpIr,
616            _b: PhantomData<B>,
617        }
618
619        impl<B: FusionBackend> Operation<B::FusionRuntime> for SwapDimsOps<B> {
620            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
621                let input = handles.get_float_tensor::<B>(&self.desc.input);
622                let output = B::float_swap_dims(input, self.desc.dim1, self.desc.dim2);
623                handles.register_float_tensor::<B>(&self.desc.out.id, output);
624            }
625        }
626
627        let mut streams = OperationStreams::default();
628        streams.tensor(&tensor);
629
630        let dtype = tensor.dtype;
631        let shape = tensor.shape.clone().swap(dim1, dim2).unwrap();
632        let mut out = tensor.client.tensor_uninitialized(shape, dtype);
633
634        let desc = SwapDimsOpIr {
635            input: tensor.into_ir(),
636            dim1,
637            dim2,
638            out: out.to_ir_out(),
639        };
640        out.client.register(
641            streams,
642            OperationIr::BaseFloat(BaseOperationIr::SwapDims(desc.clone())),
643            SwapDimsOps::<B>::new(desc),
644        );
645        out.stream = StreamId::current();
646
647        out
648    }
649
650    fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
651        if tensor.shape == shape {
652            return tensor;
653        }
654
655        #[derive(new, Debug)]
656        struct ReshapeDimsOps<B: FusionBackend> {
657            desc: UnaryOpIr,
658            _b: PhantomData<B>,
659        }
660
661        impl<B: FusionBackend> Operation<B::FusionRuntime> for ReshapeDimsOps<B> {
662            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
663                let input = handles.get_float_tensor::<B>(&self.desc.input);
664                let output = B::float_reshape(input, self.desc.out.shape.clone());
665                handles.register_float_tensor::<B>(&self.desc.out.id, output);
666            }
667        }
668
669        let mut streams = OperationStreams::default();
670        streams.tensor(&tensor);
671        let dtype = tensor.dtype;
672        let out = tensor.client.tensor_uninitialized(shape, dtype);
673
674        let desc = UnaryOpIr {
675            input: tensor.into_ir(),
676            out: out.to_ir_out(),
677        };
678        out.client.register(
679            streams,
680            OperationIr::BaseFloat(BaseOperationIr::Reshape(desc.clone())),
681            ReshapeDimsOps::<B>::new(desc),
682        );
683
684        out
685    }
686
687    fn float_gather(
688        dim: usize,
689        tensor: FloatTensor<Self>,
690        indices: IntTensor<Self>,
691    ) -> FloatTensor<Self> {
692        #[derive(new, Debug)]
693        struct GatherOps<B: FusionBackend> {
694            desc: GatherOpIr,
695            _b: PhantomData<B>,
696        }
697
698        impl<B: FusionBackend> Operation<B::FusionRuntime> for GatherOps<B> {
699            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
700                let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
701                let indices = handles.get_int_tensor::<B>(&self.desc.indices);
702
703                let output = B::float_gather(self.desc.dim, tensor, indices);
704                handles.register_float_tensor::<B>(&self.desc.out.id, output);
705            }
706        }
707
708        let mut streams = OperationStreams::default();
709        streams.tensor(&tensor);
710        streams.tensor(&indices);
711
712        let dtype = tensor.dtype;
713        let shape = indices.shape.clone();
714        let out = tensor.client.tensor_uninitialized(shape, dtype);
715
716        let desc = GatherOpIr {
717            tensor: tensor.into_ir(),
718            dim,
719            indices: indices.into_ir(),
720            out: out.to_ir_out(),
721        };
722        out.client.register(
723            streams,
724            OperationIr::NumericFloat(dtype, NumericOperationIr::Gather(desc.clone())),
725            GatherOps::<B>::new(desc),
726        );
727
728        out
729    }
730
731    fn float_scatter(
732        dim: usize,
733        tensor: FloatTensor<Self>,
734        indices: IntTensor<Self>,
735        value: FloatTensor<Self>,
736    ) -> FloatTensor<Self> {
737        #[derive(new, Debug)]
738        struct ScatterOps<B: FusionBackend> {
739            desc: ScatterOpIr,
740            _b: PhantomData<B>,
741        }
742
743        impl<B: FusionBackend> Operation<B::FusionRuntime> for ScatterOps<B> {
744            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
745                let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
746                let indices = handles.get_int_tensor::<B>(&self.desc.indices);
747                let value = handles.get_float_tensor::<B>(&self.desc.value);
748
749                let output = B::float_scatter(self.desc.dim, tensor, indices, value);
750
751                handles.register_float_tensor::<B>(&self.desc.out.id, output);
752            }
753        }
754
755        let mut streams = OperationStreams::default();
756        streams.tensor(&tensor);
757        streams.tensor(&indices);
758        streams.tensor(&value);
759
760        let shape = tensor.shape.clone();
761        let dtype = tensor.dtype;
762        let out = tensor.client.tensor_uninitialized(shape, dtype);
763
764        let desc = ScatterOpIr {
765            tensor: tensor.into_ir(),
766            dim,
767            indices: indices.into_ir(),
768            value: value.into_ir(),
769            out: out.to_ir_out(),
770        };
771        // Check that both float tensors have the same type
772        check_binary_op_types(&desc.tensor, &desc.value).unwrap();
773        out.client.register(
774            streams,
775            OperationIr::NumericFloat(dtype, NumericOperationIr::Scatter(desc.clone())),
776            ScatterOps::<B>::new(desc),
777        );
778
779        out
780    }
781
782    fn float_select(
783        tensor: FloatTensor<Self>,
784        dim: usize,
785        indices: IntTensor<Self>,
786    ) -> FloatTensor<Self> {
787        #[derive(new, Debug)]
788        struct SelectOps<B: FusionBackend> {
789            desc: SelectOpIr,
790            _b: PhantomData<B>,
791        }
792
793        impl<B: FusionBackend> Operation<B::FusionRuntime> for SelectOps<B> {
794            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
795                let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
796                let indices = handles.get_int_tensor::<B>(&self.desc.indices);
797
798                let output = B::float_select(tensor, self.desc.dim, indices);
799
800                handles.register_float_tensor::<B>(&self.desc.out.id, output);
801            }
802        }
803
804        let mut streams = OperationStreams::default();
805        streams.tensor(&tensor);
806        streams.tensor(&indices);
807
808        let dtype = tensor.dtype;
809        let mut shape = tensor.shape.clone();
810        shape[dim] = indices.shape[0];
811        let out = tensor.client.tensor_uninitialized(shape, dtype);
812        let desc = SelectOpIr {
813            tensor: tensor.into_ir(),
814            dim,
815            indices: indices.into_ir(),
816            out: out.to_ir_out(),
817        };
818        out.client.register(
819            streams,
820            OperationIr::NumericFloat(dtype, NumericOperationIr::Select(desc.clone())),
821            SelectOps::<B>::new(desc),
822        );
823
824        out
825    }
826
827    fn float_select_assign(
828        tensor: FloatTensor<Self>,
829        dim: usize,
830        indices: IntTensor<Self>,
831        value: FloatTensor<Self>,
832    ) -> FloatTensor<Self> {
833        #[derive(new, Debug)]
834        struct SelectAssignOps<B: FusionBackend> {
835            desc: SelectAssignOpIr,
836            _b: PhantomData<B>,
837        }
838
839        impl<B: FusionBackend> Operation<B::FusionRuntime> for SelectAssignOps<B> {
840            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
841                let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
842                let indices = handles.get_int_tensor::<B>(&self.desc.indices);
843                let value = handles.get_float_tensor::<B>(&self.desc.value);
844
845                let output = B::float_select_assign(tensor, self.desc.dim, indices, value);
846
847                handles.register_float_tensor::<B>(&self.desc.out.id, output);
848            }
849        }
850
851        let mut streams = OperationStreams::default();
852        streams.tensor(&tensor);
853        streams.tensor(&indices);
854        streams.tensor(&value);
855
856        let dtype = tensor.dtype;
857        let shape = tensor.shape.clone();
858        let out = tensor.client.tensor_uninitialized(shape, dtype);
859
860        let desc = SelectAssignOpIr {
861            tensor: tensor.into_ir(),
862            dim,
863            indices: indices.into_ir(),
864            value: value.into_ir(),
865            out: out.to_ir_out(),
866        };
867        // Check that both float tensors have the same type
868        check_binary_op_types(&desc.tensor, &desc.value).unwrap();
869        out.client.register(
870            streams,
871            OperationIr::NumericFloat(dtype, NumericOperationIr::SelectAssign(desc.clone())),
872            SelectAssignOps::<B>::new(desc),
873        );
874
875        out
876    }
877
878    fn float_slice(tensor: FloatTensor<Self>, slices: &[Slice]) -> FloatTensor<Self> {
879        #[derive(new, Debug)]
880        struct SliceOps<B: FusionBackend> {
881            desc: SliceOpIr,
882            _b: PhantomData<B>,
883        }
884
885        impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceOps<B> {
886            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
887                let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
888
889                let output = B::float_slice(tensor, self.desc.ranges.as_slice());
890
891                handles.register_float_tensor::<B>(&self.desc.out.id, output);
892            }
893        }
894        let mut streams = OperationStreams::default();
895        streams.tensor(&tensor);
896        let dtype = tensor.dtype;
897        let shape = tensor.shape.clone().slice(slices).unwrap();
898
899        let out = tensor.client.tensor_uninitialized(shape, dtype);
900
901        let desc = SliceOpIr {
902            tensor: tensor.into_ir(),
903            ranges: slices.to_vec(),
904            out: out.to_ir_out(),
905        };
906        out.client.register(
907            streams,
908            OperationIr::BaseFloat(BaseOperationIr::Slice(desc.clone())),
909            SliceOps::<B>::new(desc),
910        );
911
912        out
913    }
914
915    fn float_slice_assign(
916        tensor: FloatTensor<Self>,
917        ranges: &[burn_tensor::Slice],
918        value: FloatTensor<Self>,
919    ) -> FloatTensor<Self> {
920        #[derive(new, Debug)]
921        struct SliceAssignOps<B: FusionBackend> {
922            desc: SliceAssignOpIr,
923            _b: PhantomData<B>,
924        }
925
926        impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceAssignOps<B> {
927            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
928                let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
929                let value = handles.get_float_tensor::<B>(&self.desc.value);
930
931                let output = B::float_slice_assign(tensor, self.desc.ranges.as_slice(), value);
932
933                handles.register_float_tensor::<B>(&self.desc.out.id, output);
934            }
935        }
936
937        let mut streams = OperationStreams::default();
938        streams.tensor(&tensor);
939        streams.tensor(&value);
940
941        let dtype = tensor.dtype;
942        let shape = tensor.shape.clone();
943        let out = tensor.client.tensor_uninitialized(shape, dtype);
944
945        let desc = SliceAssignOpIr {
946            tensor: tensor.into_ir(),
947            ranges: ranges.to_vec(),
948            value: value.into_ir(),
949            out: out.to_ir_out(),
950        };
951        // Check that both float tensors have the same type
952        check_binary_op_types(&desc.tensor, &desc.value).unwrap();
953        out.client.register(
954            streams,
955            OperationIr::BaseFloat(BaseOperationIr::SliceAssign(desc.clone())),
956            SliceAssignOps::<B>::new(desc),
957        );
958
959        out
960    }
961
962    fn float_mask_where(
963        tensor: FloatTensor<Self>,
964        mask: BoolTensor<Self>,
965        value: FloatTensor<Self>,
966    ) -> FloatTensor<Self> {
967        #[derive(new, Debug)]
968        struct MaskWhereOps<B: FusionBackend> {
969            desc: MaskWhereOpIr,
970            _b: PhantomData<B>,
971        }
972
973        impl<B: FusionBackend> Operation<B::FusionRuntime> for MaskWhereOps<B> {
974            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
975                let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
976                let value = handles.get_float_tensor::<B>(&self.desc.value);
977                let mask = handles.get_bool_tensor::<B>(&self.desc.mask);
978
979                let output = B::float_mask_where(tensor, mask, value);
980
981                handles.register_float_tensor::<B>(&self.desc.out.id, output);
982            }
983        }
984
985        let mut streams = OperationStreams::default();
986        streams.tensor(&tensor);
987        streams.tensor(&mask);
988        streams.tensor(&value);
989
990        let dtype = tensor.dtype;
991        let out = tensor
992            .client
993            .tensor_uninitialized(tensor.shape.broadcast(&mask.shape).unwrap(), dtype);
994
995        let desc = MaskWhereOpIr {
996            tensor: tensor.into_ir(),
997            value: value.into_ir(),
998            mask: mask.into_ir(),
999            out: out.to_ir_out(),
1000        };
1001        // Check that both float tensors have the same type
1002        check_binary_op_types(&desc.tensor, &desc.value).unwrap();
1003        out.client.register(
1004            streams,
1005            OperationIr::NumericFloat(dtype, NumericOperationIr::MaskWhere(desc.clone())),
1006            MaskWhereOps::<B>::new(desc),
1007        );
1008
1009        out
1010    }
1011
1012    fn float_mask_fill(
1013        tensor: FloatTensor<Self>,
1014        mask: BoolTensor<Self>,
1015        value: FloatElem<Self>,
1016    ) -> FloatTensor<Self> {
1017        #[derive(new, Debug)]
1018        struct MaskFillOps<B: FusionBackend> {
1019            desc: MaskFillOpIr,
1020            _b: PhantomData<B>,
1021        }
1022
1023        impl<B: FusionBackend> Operation<B::FusionRuntime> for MaskFillOps<B> {
1024            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1025                let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
1026                let mask = handles.get_bool_tensor::<B>(&self.desc.mask);
1027
1028                let output = B::float_mask_fill(tensor, mask, self.desc.value.elem());
1029
1030                handles.register_float_tensor::<B>(&self.desc.out.id, output);
1031            }
1032        }
1033
1034        let mut streams = OperationStreams::default();
1035        streams.tensor(&tensor);
1036        streams.tensor(&mask);
1037
1038        let dtype = tensor.dtype;
1039        let shape = tensor.shape.clone();
1040        let out = tensor.client.tensor_uninitialized(shape, dtype);
1041        let desc = MaskFillOpIr {
1042            tensor: tensor.into_ir(),
1043            value: ScalarIr::with_dtype(value, &dtype),
1044            mask: mask.into_ir(),
1045            out: out.to_ir_out(),
1046        };
1047        out.client.register(
1048            streams,
1049            OperationIr::NumericFloat(dtype, NumericOperationIr::MaskFill(desc.clone())),
1050            MaskFillOps::<B>::new(desc),
1051        );
1052
1053        out
1054    }
1055
1056    fn float_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
1057        binary_float_cmp_ops!(EqualOps, B::float_equal);
1058
1059        let mut streams = OperationStreams::default();
1060        streams.tensor(&lhs);
1061        streams.tensor(&rhs);
1062        let out = lhs.client.tensor_uninitialized(
1063            lhs.shape.broadcast(&rhs.shape).unwrap(),
1064            B::BoolElem::dtype(),
1065        );
1066
1067        let desc = BinaryOpIr {
1068            lhs: lhs.into_ir(),
1069            rhs: rhs.into_ir(),
1070            out: out.to_ir_out(),
1071        };
1072        out.client.register(
1073            streams,
1074            OperationIr::BaseFloat(BaseOperationIr::Equal(desc.clone())),
1075            EqualOps::<B>::new(desc),
1076        );
1077
1078        out
1079    }
1080
1081    fn float_equal_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
1082        scalar_float_cmp_ops!(EqualElemOps, B::float_equal_elem);
1083
1084        let mut streams = OperationStreams::default();
1085        streams.tensor(&lhs);
1086        let dtype = lhs.dtype;
1087        let out = lhs
1088            .client
1089            .tensor_uninitialized(lhs.shape.clone(), B::BoolElem::dtype());
1090
1091        let desc = ScalarOpIr {
1092            lhs: lhs.into_ir(),
1093            rhs: ScalarIr::with_dtype(rhs, &dtype),
1094            out: out.to_ir_out(),
1095        };
1096        out.client.register(
1097            streams,
1098            OperationIr::NumericFloat(dtype, NumericOperationIr::EqualElem(desc.clone())),
1099            EqualElemOps::<B>::new(desc),
1100        );
1101
1102        out
1103    }
1104
1105    fn float_greater(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
1106        binary_float_cmp_ops!(GreaterOps, B::float_greater);
1107
1108        let mut streams = OperationStreams::default();
1109        streams.tensor(&lhs);
1110        streams.tensor(&rhs);
1111        let dtype = lhs.dtype;
1112        let out = lhs.client.tensor_uninitialized(
1113            lhs.shape.broadcast(&rhs.shape).unwrap(),
1114            B::BoolElem::dtype(),
1115        );
1116
1117        let desc = BinaryOpIr {
1118            lhs: lhs.into_ir(),
1119            rhs: rhs.into_ir(),
1120            out: out.to_ir_out(),
1121        };
1122        out.client.register(
1123            streams,
1124            OperationIr::NumericFloat(dtype, NumericOperationIr::Greater(desc.clone())),
1125            GreaterOps::<B>::new(desc),
1126        );
1127
1128        out
1129    }
1130
1131    fn float_greater_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
1132        scalar_float_cmp_ops!(GreaterElemOps, B::float_greater_elem);
1133
1134        let mut streams = OperationStreams::default();
1135        streams.tensor(&lhs);
1136        let dtype = lhs.dtype;
1137        let out = lhs
1138            .client
1139            .tensor_uninitialized(lhs.shape.clone(), B::BoolElem::dtype());
1140
1141        let desc = ScalarOpIr {
1142            lhs: lhs.into_ir(),
1143            rhs: ScalarIr::with_dtype(rhs, &dtype),
1144            out: out.to_ir_out(),
1145        };
1146        out.client.register(
1147            streams,
1148            OperationIr::NumericFloat(dtype, NumericOperationIr::GreaterElem(desc.clone())),
1149            GreaterElemOps::<B>::new(desc),
1150        );
1151
1152        out
1153    }
1154
1155    fn float_greater_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
1156        binary_float_cmp_ops!(GreaterEqualOps, B::float_greater_equal);
1157
1158        let mut streams = OperationStreams::default();
1159        streams.tensor(&lhs);
1160        streams.tensor(&rhs);
1161        let dtype = lhs.dtype;
1162        let out = lhs.client.tensor_uninitialized(
1163            lhs.shape.broadcast(&rhs.shape).unwrap(),
1164            B::BoolElem::dtype(),
1165        );
1166
1167        let desc = BinaryOpIr {
1168            lhs: lhs.into_ir(),
1169            rhs: rhs.into_ir(),
1170            out: out.to_ir_out(),
1171        };
1172        out.client.register(
1173            streams,
1174            OperationIr::NumericFloat(dtype, NumericOperationIr::GreaterEqual(desc.clone())),
1175            GreaterEqualOps::<B>::new(desc),
1176        );
1177
1178        out
1179    }
1180
1181    fn float_greater_equal_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
1182        scalar_float_cmp_ops!(GreaterEqualElemOps, B::float_greater_equal_elem);
1183
1184        let mut streams = OperationStreams::default();
1185        streams.tensor(&lhs);
1186        let dtype = lhs.dtype;
1187        let out = lhs
1188            .client
1189            .tensor_uninitialized(lhs.shape.clone(), B::BoolElem::dtype());
1190
1191        let desc = ScalarOpIr {
1192            lhs: lhs.into_ir(),
1193            rhs: ScalarIr::with_dtype(rhs, &dtype),
1194            out: out.to_ir_out(),
1195        };
1196        out.client.register(
1197            streams,
1198            OperationIr::NumericFloat(dtype, NumericOperationIr::GreaterEqualElem(desc.clone())),
1199            GreaterEqualElemOps::<B>::new(desc),
1200        );
1201
1202        out
1203    }
1204
1205    fn float_lower(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
1206        binary_float_cmp_ops!(LowerOps, B::float_lower);
1207
1208        let mut streams = OperationStreams::default();
1209        streams.tensor(&lhs);
1210        streams.tensor(&rhs);
1211        let dtype = lhs.dtype;
1212        let out = lhs.client.tensor_uninitialized(
1213            lhs.shape.broadcast(&rhs.shape).unwrap(),
1214            B::BoolElem::dtype(),
1215        );
1216
1217        let desc = BinaryOpIr {
1218            lhs: lhs.into_ir(),
1219            rhs: rhs.into_ir(),
1220            out: out.to_ir_out(),
1221        };
1222        out.client.register(
1223            streams,
1224            OperationIr::NumericFloat(dtype, NumericOperationIr::Lower(desc.clone())),
1225            LowerOps::<B>::new(desc),
1226        );
1227
1228        out
1229    }
1230
1231    fn float_lower_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
1232        scalar_float_cmp_ops!(LowerElemOps, B::float_lower_elem);
1233
1234        let mut streams = OperationStreams::default();
1235        streams.tensor(&lhs);
1236        let dtype = lhs.dtype;
1237        let out = lhs
1238            .client
1239            .tensor_uninitialized(lhs.shape.clone(), B::BoolElem::dtype());
1240
1241        let desc = ScalarOpIr {
1242            lhs: lhs.into_ir(),
1243            rhs: ScalarIr::with_dtype(rhs, &dtype),
1244            out: out.to_ir_out(),
1245        };
1246        out.client.register(
1247            streams,
1248            OperationIr::NumericFloat(dtype, NumericOperationIr::LowerElem(desc.clone())),
1249            LowerElemOps::<B>::new(desc),
1250        );
1251
1252        out
1253    }
1254
1255    fn float_lower_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
1256        binary_float_cmp_ops!(LowerEqualOps, B::float_lower_equal);
1257
1258        let mut streams = OperationStreams::default();
1259        streams.tensor(&lhs);
1260        streams.tensor(&rhs);
1261        let dtype = lhs.dtype;
1262        let out = lhs.client.tensor_uninitialized(
1263            lhs.shape.broadcast(&rhs.shape).unwrap(),
1264            B::BoolElem::dtype(),
1265        );
1266
1267        let desc = BinaryOpIr {
1268            lhs: lhs.into_ir(),
1269            rhs: rhs.into_ir(),
1270            out: out.to_ir_out(),
1271        };
1272        out.client.register(
1273            streams,
1274            OperationIr::NumericFloat(dtype, NumericOperationIr::LowerEqual(desc.clone())),
1275            LowerEqualOps::<B>::new(desc),
1276        );
1277
1278        out
1279    }
1280
1281    fn float_lower_equal_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
1282        scalar_float_cmp_ops!(LowerEqualElemOps, B::float_lower_equal_elem);
1283
1284        let mut streams = OperationStreams::default();
1285        streams.tensor(&lhs);
1286        let dtype = lhs.dtype;
1287        let out = lhs
1288            .client
1289            .tensor_uninitialized(lhs.shape.clone(), B::BoolElem::dtype());
1290
1291        let desc = ScalarOpIr {
1292            lhs: lhs.into_ir(),
1293            rhs: ScalarIr::with_dtype(rhs, &dtype),
1294            out: out.to_ir_out(),
1295        };
1296        out.client.register(
1297            streams,
1298            OperationIr::NumericFloat(dtype, NumericOperationIr::LowerEqualElem(desc.clone())),
1299            LowerEqualElemOps::<B>::new(desc),
1300        );
1301
1302        out
1303    }
1304
1305    fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1306        unary_float_ops!(SumOps, B::float_sum, reduce);
1307
1308        let mut streams = OperationStreams::default();
1309        streams.tensor(&tensor);
1310        let dtype = tensor.dtype;
1311        let out = tensor.client.tensor_uninitialized(Shape::new([1]), dtype);
1312
1313        let desc = UnaryOpIr {
1314            input: tensor.into_ir(),
1315            out: out.to_ir_out(),
1316        };
1317        out.client.register(
1318            streams,
1319            OperationIr::NumericFloat(dtype, NumericOperationIr::Sum(desc.clone())),
1320            SumOps::<B>::new(desc),
1321        );
1322
1323        out
1324    }
1325
1326    fn float_sum_dim(tensor: FloatTensor<Self>, axis: usize) -> FloatTensor<Self> {
1327        reduce_float_ops!(SumDimOps, B::float_sum_dim);
1328
1329        let mut streams = OperationStreams::default();
1330        streams.tensor(&tensor);
1331        let dtype = tensor.dtype;
1332        let mut shape = tensor.shape.clone();
1333        shape[axis] = 1;
1334        let out = tensor.client.tensor_uninitialized(shape, dtype);
1335
1336        let desc = ReduceDimOpIr {
1337            input: tensor.into_ir(),
1338            out: out.to_ir_out(),
1339            axis,
1340        };
1341        out.client.register(
1342            streams,
1343            OperationIr::NumericFloat(dtype, NumericOperationIr::SumDim(desc.clone())),
1344            SumDimOps::<B>::new(desc),
1345        );
1346
1347        out
1348    }
1349
1350    fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1351        unary_float_ops!(ProdOps, B::float_prod, reduce);
1352
1353        let mut streams = OperationStreams::default();
1354        streams.tensor(&tensor);
1355        let dtype = tensor.dtype;
1356        let out = tensor.client.tensor_uninitialized(Shape::new([1]), dtype);
1357
1358        let desc = UnaryOpIr {
1359            input: tensor.into_ir(),
1360            out: out.to_ir_out(),
1361        };
1362        out.client.register(
1363            streams,
1364            OperationIr::NumericFloat(dtype, NumericOperationIr::Prod(desc.clone())),
1365            ProdOps::<B>::new(desc),
1366        );
1367
1368        out
1369    }
1370
1371    fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
1372        reduce_float_ops!(ProdDimOps, B::float_prod_dim);
1373
1374        let mut streams = OperationStreams::default();
1375        streams.tensor(&tensor);
1376        let dtype = tensor.dtype;
1377        let mut shape = tensor.shape.clone();
1378        shape[dim] = 1;
1379        let out = tensor.client.tensor_uninitialized(shape, dtype);
1380
1381        let desc = ReduceDimOpIr {
1382            input: tensor.into_ir(),
1383            axis: dim,
1384            out: out.to_ir_out(),
1385        };
1386
1387        out.client.register(
1388            streams,
1389            OperationIr::NumericFloat(dtype, NumericOperationIr::ProdDim(desc.clone())),
1390            ProdDimOps::<B>::new(desc),
1391        );
1392
1393        out
1394    }
1395
1396    fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1397        unary_float_ops!(MeanOps, B::float_mean, reduce);
1398
1399        let mut streams = OperationStreams::default();
1400        streams.tensor(&tensor);
1401        let dtype = tensor.dtype;
1402        let out = tensor.client.tensor_uninitialized(Shape::new([1]), dtype);
1403
1404        let desc = UnaryOpIr {
1405            input: tensor.into_ir(),
1406            out: out.to_ir_out(),
1407        };
1408        out.client.register(
1409            streams,
1410            OperationIr::NumericFloat(dtype, NumericOperationIr::Mean(desc.clone())),
1411            MeanOps::<B>::new(desc),
1412        );
1413
1414        out
1415    }
1416
1417    fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
1418        reduce_float_ops!(MeanDimOps, B::float_mean_dim);
1419
1420        let mut streams = OperationStreams::default();
1421        streams.tensor(&tensor);
1422        let dtype = tensor.dtype;
1423        let mut shape = tensor.shape.clone();
1424        shape[dim] = 1;
1425        let out = tensor.client.tensor_uninitialized(shape, dtype);
1426
1427        let desc = ReduceDimOpIr {
1428            input: tensor.into_ir(),
1429            axis: dim,
1430            out: out.to_ir_out(),
1431        };
1432        out.client.register(
1433            streams,
1434            OperationIr::NumericFloat(dtype, NumericOperationIr::MeanDim(desc.clone())),
1435            MeanDimOps::<B>::new(desc),
1436        );
1437
1438        out
1439    }
1440
1441    fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
1442        #[derive(new, Debug)]
1443        struct CumsumOps<B: FusionBackend> {
1444            desc: DimOpIr,
1445            _b: PhantomData<B>,
1446        }
1447
1448        impl<B: FusionBackend> Operation<B::FusionRuntime> for CumsumOps<B> {
1449            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1450                let input = handles.get_float_tensor::<B>(&self.desc.input);
1451                let output = B::float_cumsum(input, self.desc.axis);
1452                handles.register_float_tensor::<B>(&self.desc.out.id, output);
1453            }
1454        }
1455
1456        let mut streams = OperationStreams::default();
1457        streams.tensor(&tensor);
1458        let dtype = tensor.dtype;
1459        let shape = tensor.shape.clone();
1460        let out = tensor.client.tensor_uninitialized(shape, dtype);
1461
1462        let desc = DimOpIr {
1463            input: tensor.into_ir(),
1464            out: out.to_ir_out(),
1465            axis: dim,
1466        };
1467
1468        out.client.register(
1469            streams,
1470            OperationIr::BaseFloat(BaseOperationIr::CumSum(desc.clone())),
1471            CumsumOps::<B>::new(desc),
1472        );
1473
1474        out
1475    }
1476
1477    fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
1478        #[derive(new, Debug)]
1479        struct CumprodOps<B: FusionBackend> {
1480            desc: DimOpIr,
1481            _b: PhantomData<B>,
1482        }
1483
1484        impl<B: FusionBackend> Operation<B::FusionRuntime> for CumprodOps<B> {
1485            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1486                let input = handles.get_float_tensor::<B>(&self.desc.input);
1487                let output = B::float_cumprod(input, self.desc.axis);
1488                handles.register_float_tensor::<B>(&self.desc.out.id, output);
1489            }
1490        }
1491
1492        let mut streams = OperationStreams::default();
1493        streams.tensor(&tensor);
1494        let dtype = tensor.dtype;
1495        let shape = tensor.shape.clone();
1496        let out = tensor.client.tensor_uninitialized(shape, dtype);
1497
1498        let desc = DimOpIr {
1499            input: tensor.into_ir(),
1500            out: out.to_ir_out(),
1501            axis: dim,
1502        };
1503
1504        out.client.register(
1505            streams,
1506            OperationIr::BaseFloat(BaseOperationIr::CumProd(desc.clone())),
1507            CumprodOps::<B>::new(desc),
1508        );
1509
1510        out
1511    }
1512
1513    fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
1514        #[derive(new, Debug)]
1515        struct CumminOps<B: FusionBackend> {
1516            desc: DimOpIr,
1517            _b: PhantomData<B>,
1518        }
1519
1520        impl<B: FusionBackend> Operation<B::FusionRuntime> for CumminOps<B> {
1521            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1522                let input = handles.get_float_tensor::<B>(&self.desc.input);
1523                let output = B::float_cummin(input, self.desc.axis);
1524                handles.register_float_tensor::<B>(&self.desc.out.id, output);
1525            }
1526        }
1527
1528        let mut streams = OperationStreams::default();
1529        streams.tensor(&tensor);
1530        let dtype = tensor.dtype;
1531        let shape = tensor.shape.clone();
1532        let out = tensor.client.tensor_uninitialized(shape, dtype);
1533
1534        let desc = DimOpIr {
1535            input: tensor.into_ir(),
1536            out: out.to_ir_out(),
1537            axis: dim,
1538        };
1539
1540        out.client.register(
1541            streams,
1542            OperationIr::BaseFloat(BaseOperationIr::CumMin(desc.clone())),
1543            CumminOps::<B>::new(desc),
1544        );
1545
1546        out
1547    }
1548
1549    fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
1550        #[derive(new, Debug)]
1551        struct CummaxOps<B: FusionBackend> {
1552            desc: DimOpIr,
1553            _b: PhantomData<B>,
1554        }
1555
1556        impl<B: FusionBackend> Operation<B::FusionRuntime> for CummaxOps<B> {
1557            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1558                let input = handles.get_float_tensor::<B>(&self.desc.input);
1559                let output = B::float_cummax(input, self.desc.axis);
1560                handles.register_float_tensor::<B>(&self.desc.out.id, output);
1561            }
1562        }
1563
1564        let mut streams = OperationStreams::default();
1565        streams.tensor(&tensor);
1566        let dtype = tensor.dtype;
1567        let shape = tensor.shape.clone();
1568        let out = tensor.client.tensor_uninitialized(shape, dtype);
1569
1570        let desc = DimOpIr {
1571            input: tensor.into_ir(),
1572            out: out.to_ir_out(),
1573            axis: dim,
1574        };
1575
1576        out.client.register(
1577            streams,
1578            OperationIr::BaseFloat(BaseOperationIr::CumMax(desc.clone())),
1579            CummaxOps::<B>::new(desc),
1580        );
1581
1582        out
1583    }
1584
1585    fn float_exp(lhs: FloatTensor<Self>) -> FloatTensor<Self> {
1586        unary_float_ops!(ExpOps, B::float_exp);
1587
1588        let mut streams = OperationStreams::default();
1589        streams.tensor(&lhs);
1590        let dtype = lhs.dtype;
1591        let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype);
1592
1593        let desc = UnaryOpIr {
1594            input: lhs.into_ir(),
1595            out: out.to_ir_out(),
1596        };
1597        out.client.register(
1598            streams,
1599            OperationIr::Float(dtype, FloatOperationIr::Exp(desc.clone())),
1600            ExpOps::<B>::new(desc),
1601        );
1602
1603        out
1604    }
1605
1606    fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1607        unary_float_ops!(LogOps, B::float_log);
1608
1609        let mut streams = OperationStreams::default();
1610        streams.tensor(&tensor);
1611        let dtype = tensor.dtype;
1612        let out = tensor
1613            .client
1614            .tensor_uninitialized(tensor.shape.clone(), dtype);
1615
1616        let desc = UnaryOpIr {
1617            input: tensor.into_ir(),
1618            out: out.to_ir_out(),
1619        };
1620        out.client.register(
1621            streams,
1622            OperationIr::Float(dtype, FloatOperationIr::Log(desc.clone())),
1623            LogOps::<B>::new(desc),
1624        );
1625
1626        out
1627    }
1628
1629    fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1630        unary_float_ops!(Log1pOps, B::float_log1p);
1631
1632        let mut streams = OperationStreams::default();
1633        streams.tensor(&tensor);
1634        let dtype = tensor.dtype;
1635        let out = tensor
1636            .client
1637            .tensor_uninitialized(tensor.shape.clone(), dtype);
1638
1639        let desc = UnaryOpIr {
1640            input: tensor.into_ir(),
1641            out: out.to_ir_out(),
1642        };
1643        out.client.register(
1644            streams,
1645            OperationIr::Float(dtype, FloatOperationIr::Log1p(desc.clone())),
1646            Log1pOps::<B>::new(desc),
1647        );
1648
1649        out
1650    }
1651
1652    fn float_powf_scalar_impl(lhs: FloatTensor<Self>, rhs: f32) -> FloatTensor<Self> {
1653        scalar_float_ops!(PowfOps, B::float_powf_scalar);
1654
1655        let mut streams = OperationStreams::default();
1656        streams.tensor(&lhs);
1657        let dtype = lhs.dtype;
1658        let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype);
1659
1660        let desc = ScalarOpIr {
1661            lhs: lhs.into_ir(),
1662            rhs: ScalarIr::with_dtype(rhs, &dtype),
1663            out: out.to_ir_out(),
1664        };
1665        out.client.register(
1666            streams,
1667            OperationIr::Float(dtype, FloatOperationIr::PowfScalar(desc.clone())),
1668            PowfOps::<B>::new(desc),
1669        );
1670
1671        out
1672    }
1673
1674    fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1675        unary_float_ops!(SqrtOps, B::float_sqrt);
1676
1677        let mut streams = OperationStreams::default();
1678        streams.tensor(&tensor);
1679        let dtype = tensor.dtype;
1680        let out = tensor
1681            .client
1682            .tensor_uninitialized(tensor.shape.clone(), dtype);
1683
1684        let desc = UnaryOpIr {
1685            input: tensor.into_ir(),
1686            out: out.to_ir_out(),
1687        };
1688        out.client.register(
1689            streams,
1690            OperationIr::Float(dtype, FloatOperationIr::Sqrt(desc.clone())),
1691            SqrtOps::<B>::new(desc),
1692        );
1693
1694        out
1695    }
1696
1697    fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1698        unary_float_ops!(AbsOps, B::float_abs);
1699
1700        let mut streams = OperationStreams::default();
1701        streams.tensor(&tensor);
1702        let dtype = tensor.dtype;
1703        let out = tensor
1704            .client
1705            .tensor_uninitialized(tensor.shape.clone(), dtype);
1706
1707        let desc = UnaryOpIr {
1708            input: tensor.into_ir(),
1709            out: out.to_ir_out(),
1710        };
1711        out.client.register(
1712            streams,
1713            OperationIr::NumericFloat(dtype, NumericOperationIr::Abs(desc.clone())),
1714            AbsOps::<B>::new(desc),
1715        );
1716
1717        out
1718    }
1719
1720    fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1721        unary_float_ops!(CosOps, B::float_cos);
1722
1723        let mut streams = OperationStreams::default();
1724        streams.tensor(&tensor);
1725        let dtype = tensor.dtype;
1726        let out = tensor
1727            .client
1728            .tensor_uninitialized(tensor.shape.clone(), dtype);
1729
1730        let desc = UnaryOpIr {
1731            input: tensor.into_ir(),
1732            out: out.to_ir_out(),
1733        };
1734        out.client.register(
1735            streams,
1736            OperationIr::Float(dtype, FloatOperationIr::Cos(desc.clone())),
1737            CosOps::<B>::new(desc),
1738        );
1739
1740        out
1741    }
1742
1743    fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1744        unary_float_ops!(SinOps, B::float_sin);
1745
1746        let mut streams = OperationStreams::default();
1747        streams.tensor(&tensor);
1748        let dtype = tensor.dtype;
1749        let out = tensor
1750            .client
1751            .tensor_uninitialized(tensor.shape.clone(), dtype);
1752
1753        let desc = UnaryOpIr {
1754            input: tensor.into_ir(),
1755            out: out.to_ir_out(),
1756        };
1757        out.client.register(
1758            streams,
1759            OperationIr::Float(dtype, FloatOperationIr::Sin(desc.clone())),
1760            SinOps::<B>::new(desc),
1761        );
1762
1763        out
1764    }
1765
1766    fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1767        unary_float_ops!(TanhOps, B::float_tanh);
1768
1769        let mut streams = OperationStreams::default();
1770        streams.tensor(&tensor);
1771        let dtype = tensor.dtype;
1772        let out = tensor
1773            .client
1774            .tensor_uninitialized(tensor.shape.clone(), dtype);
1775
1776        let desc = UnaryOpIr {
1777            input: tensor.into_ir(),
1778            out: out.to_ir_out(),
1779        };
1780        out.client.register(
1781            streams,
1782            OperationIr::Float(dtype, FloatOperationIr::Tanh(desc.clone())),
1783            TanhOps::<B>::new(desc),
1784        );
1785
1786        out
1787    }
1788
1789    fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1790        unary_float_ops!(Recip, B::float_recip);
1791
1792        let mut streams = OperationStreams::default();
1793        streams.tensor(&tensor);
1794        let dtype = tensor.dtype;
1795        let out = tensor
1796            .client
1797            .tensor_uninitialized(tensor.shape.clone(), dtype);
1798        let desc = UnaryOpIr {
1799            input: tensor.into_ir(),
1800            out: out.to_ir_out(),
1801        };
1802        out.client.register(
1803            streams,
1804            OperationIr::Float(dtype, FloatOperationIr::Recip(desc.clone())),
1805            Recip::<B>::new(desc),
1806        );
1807
1808        out
1809    }
1810
1811    fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1812        unary_float_ops!(TanhOps, B::float_erf);
1813
1814        let mut streams = OperationStreams::default();
1815        streams.tensor(&tensor);
1816        let dtype = tensor.dtype;
1817        let out = tensor
1818            .client
1819            .tensor_uninitialized(tensor.shape.clone(), dtype);
1820
1821        let desc = UnaryOpIr {
1822            input: tensor.into_ir(),
1823            out: out.to_ir_out(),
1824        };
1825        out.client.register(
1826            streams,
1827            OperationIr::Float(dtype, FloatOperationIr::Erf(desc.clone())),
1828            TanhOps::<B>::new(desc),
1829        );
1830
1831        out
1832    }
1833
1834    fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {
1835        #[derive(new, Debug)]
1836        struct CatOps<B: FusionBackend> {
1837            desc: CatOpIr,
1838            _b: PhantomData<B>,
1839        }
1840
1841        impl<B: FusionBackend> Operation<B::FusionRuntime> for CatOps<B> {
1842            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1843                let tensors = self
1844                    .desc
1845                    .tensors
1846                    .iter()
1847                    .map(|tensor| handles.get_float_tensor::<B>(tensor))
1848                    .collect();
1849
1850                let output = B::float_cat(tensors, self.desc.dim);
1851
1852                handles.register_float_tensor::<B>(&self.desc.out.id, output);
1853            }
1854        }
1855
1856        let tensor_first = tensors.first().unwrap();
1857        let dtype = tensor_first.dtype;
1858        let client = tensor_first.client.clone();
1859
1860        let mut streams = OperationStreams::default();
1861        tensors.iter().for_each(|tensor| streams.tensor(tensor));
1862
1863        // Calculate the output shape
1864        let shape = Shape::cat(tensors.iter().map(|t| &t.shape), dim).unwrap();
1865
1866        let out = client.tensor_uninitialized(shape, dtype);
1867
1868        let desc = CatOpIr {
1869            tensors: tensors.into_iter().map(|t| t.into_ir()).collect(),
1870            dim,
1871            out: out.to_ir_out(),
1872        };
1873        desc.tensors
1874            .windows(2)
1875            .for_each(|desc| check_binary_op_types(&desc[0], &desc[1]).unwrap());
1876        client.register(
1877            streams,
1878            OperationIr::BaseFloat(BaseOperationIr::Cat(desc.clone())),
1879            CatOps::<B>::new(desc),
1880        );
1881
1882        out
1883    }
1884
1885    fn float_argmax(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {
1886        reduce_float2int_ops!(ArgMaxOps, B::float_argmax);
1887
1888        let mut streams = OperationStreams::default();
1889        streams.tensor(&tensor);
1890        let dtype = tensor.dtype;
1891        let mut shape = tensor.shape.clone();
1892        shape[dim] = 1;
1893        let out = tensor
1894            .client
1895            .tensor_uninitialized(shape, B::IntElem::dtype());
1896
1897        let desc = ReduceDimOpIr {
1898            input: tensor.into_ir(),
1899            axis: dim,
1900            out: out.to_ir_out(),
1901        };
1902        out.client.register(
1903            streams,
1904            OperationIr::NumericFloat(dtype, NumericOperationIr::ArgMax(desc.clone())),
1905            ArgMaxOps::<B>::new(desc),
1906        );
1907
1908        out
1909    }
1910
1911    fn float_repeat_dim(tensor: FloatTensor<Self>, dim: usize, times: usize) -> FloatTensor<Self> {
1912        #[derive(new, Debug)]
1913        struct RepeatDimOps<B: FusionBackend> {
1914            desc: RepeatDimOpIr,
1915            _b: PhantomData<B>,
1916        }
1917
1918        impl<B: FusionBackend> Operation<B::FusionRuntime> for RepeatDimOps<B> {
1919            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1920                let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
1921
1922                let output = B::float_repeat_dim(tensor, self.desc.dim, self.desc.times);
1923
1924                handles.register_float_tensor::<B>(&self.desc.out.id, output);
1925            }
1926        }
1927
1928        let mut streams = OperationStreams::default();
1929        streams.tensor(&tensor);
1930        let shape = tensor.shape.clone().repeat(dim, times);
1931        let out = tensor.client.tensor_uninitialized(shape, tensor.dtype);
1932
1933        let desc = RepeatDimOpIr {
1934            tensor: tensor.into_ir(),
1935            dim,
1936            times,
1937            out: out.to_ir_out(),
1938        };
1939        out.client.register(
1940            streams,
1941            OperationIr::BaseFloat(BaseOperationIr::RepeatDim(desc.clone())),
1942            RepeatDimOps::<B>::new(desc),
1943        );
1944
1945        out
1946    }
1947
1948    fn float_argmin(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {
1949        reduce_float2int_ops!(ArgMinOps, B::float_argmin);
1950
1951        let mut streams = OperationStreams::default();
1952        streams.tensor(&tensor);
1953        let mut shape = tensor.shape.clone();
1954        shape[dim] = 1;
1955        let dtype = tensor.dtype;
1956        let out = tensor
1957            .client
1958            .tensor_uninitialized(shape, B::IntElem::dtype());
1959
1960        let desc = ReduceDimOpIr {
1961            input: tensor.into_ir(),
1962            axis: dim,
1963            out: out.to_ir_out(),
1964        };
1965        out.client.register(
1966            streams,
1967            OperationIr::NumericFloat(dtype, NumericOperationIr::ArgMin(desc.clone())),
1968            ArgMinOps::<B>::new(desc),
1969        );
1970
1971        out
1972    }
1973
1974    fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1975        unary_float_ops!(MaxOps, B::float_max, reduce);
1976
1977        let mut streams = OperationStreams::default();
1978        streams.tensor(&tensor);
1979        let dtype = tensor.dtype;
1980        let out = tensor.client.tensor_uninitialized(Shape::new([1]), dtype);
1981
1982        let desc = UnaryOpIr {
1983            input: tensor.into_ir(),
1984            out: out.to_ir_out(),
1985        };
1986        out.client.register(
1987            streams,
1988            OperationIr::NumericFloat(dtype, NumericOperationIr::Max(desc.clone())),
1989            MaxOps::<B>::new(desc),
1990        );
1991
1992        out
1993    }
1994
1995    fn float_max_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
1996        reduce_float_ops!(MaxDimOps, B::float_max_dim);
1997
1998        let mut streams = OperationStreams::default();
1999        streams.tensor(&tensor);
2000        let mut shape = tensor.shape.clone();
2001        let dtype = tensor.dtype;
2002        shape[dim] = 1;
2003        let out = tensor.client.tensor_uninitialized(shape, dtype);
2004
2005        let desc = ReduceDimOpIr {
2006            input: tensor.into_ir(),
2007            axis: dim,
2008            out: out.to_ir_out(),
2009        };
2010        out.client.register(
2011            streams,
2012            OperationIr::NumericFloat(dtype, NumericOperationIr::MaxDim(desc.clone())),
2013            MaxDimOps::<B>::new(desc),
2014        );
2015
2016        out
2017    }
2018
2019    fn float_max_dim_with_indices(
2020        tensor: FloatTensor<Self>,
2021        dim: usize,
2022    ) -> (FloatTensor<Self>, IntTensor<Self>) {
2023        #[derive(new, Debug)]
2024        struct MaxDimWithIndicesOps<B: FusionBackend> {
2025            desc: ReduceDimWithIndicesOpIr,
2026            _b: PhantomData<B>,
2027        }
2028
2029        impl<B: FusionBackend> Operation<B::FusionRuntime> for MaxDimWithIndicesOps<B> {
2030            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2031                let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
2032                let (output, indices) = B::float_max_dim_with_indices(tensor, self.desc.dim);
2033
2034                handles.register_float_tensor::<B>(&self.desc.out.id, output);
2035                handles.register_int_tensor::<B>(&self.desc.out_indices.id, indices);
2036            }
2037        }
2038
2039        let mut streams = OperationStreams::default();
2040        streams.tensor(&tensor);
2041
2042        let mut shape = tensor.shape.clone();
2043        shape[dim] = 1;
2044        let dtype = tensor.dtype;
2045        let client = tensor.client.clone();
2046        let out = client.tensor_uninitialized(shape.clone(), dtype);
2047        let out_indices = client.tensor_uninitialized(shape, B::IntElem::dtype());
2048
2049        let desc = ReduceDimWithIndicesOpIr {
2050            tensor: tensor.into_ir(),
2051            dim,
2052            out: out.to_ir_out(),
2053            out_indices: out_indices.to_ir_out(),
2054        };
2055        client.register(
2056            streams,
2057            OperationIr::NumericFloat(dtype, NumericOperationIr::MaxDimWithIndices(desc.clone())),
2058            MaxDimWithIndicesOps::<B>::new(desc),
2059        );
2060
2061        (out, out_indices)
2062    }
2063
2064    fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
2065        unary_float_ops!(MinOps, B::float_min, reduce);
2066
2067        let mut streams = OperationStreams::default();
2068        streams.tensor(&tensor);
2069        let dtype = tensor.dtype;
2070        let out = tensor.client.tensor_uninitialized(Shape::new([1]), dtype);
2071
2072        let desc = UnaryOpIr {
2073            input: tensor.into_ir(),
2074            out: out.to_ir_out(),
2075        };
2076        out.client.register(
2077            streams,
2078            OperationIr::NumericFloat(dtype, NumericOperationIr::Min(desc.clone())),
2079            MinOps::<B>::new(desc),
2080        );
2081
2082        out
2083    }
2084
2085    fn float_min_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
2086        reduce_float_ops!(MinDimOps, B::float_min_dim);
2087
2088        let mut streams = OperationStreams::default();
2089        streams.tensor(&tensor);
2090        let dtype = tensor.dtype;
2091        let mut shape = tensor.shape.clone();
2092        shape[dim] = 1;
2093        let out = tensor.client.tensor_uninitialized(shape, dtype);
2094
2095        let desc = ReduceDimOpIr {
2096            input: tensor.into_ir(),
2097            axis: dim,
2098            out: out.to_ir_out(),
2099        };
2100        out.client.register(
2101            streams,
2102            OperationIr::NumericFloat(dtype, NumericOperationIr::MinDim(desc.clone())),
2103            MinDimOps::<B>::new(desc),
2104        );
2105
2106        out
2107    }
2108
2109    fn float_min_dim_with_indices(
2110        tensor: FloatTensor<Self>,
2111        dim: usize,
2112    ) -> (FloatTensor<Self>, IntTensor<Self>) {
2113        #[derive(new, Debug)]
2114        struct MinDimWithIndicesOps<B: FusionBackend> {
2115            desc: ReduceDimWithIndicesOpIr,
2116            _b: PhantomData<B>,
2117        }
2118
2119        impl<B: FusionBackend> Operation<B::FusionRuntime> for MinDimWithIndicesOps<B> {
2120            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2121                let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
2122                let (output, indices) = B::float_min_dim_with_indices(tensor, self.desc.dim);
2123
2124                handles.register_float_tensor::<B>(&self.desc.out.id, output);
2125                handles.register_int_tensor::<B>(&self.desc.out_indices.id, indices);
2126            }
2127        }
2128
2129        let mut streams = OperationStreams::default();
2130        streams.tensor(&tensor);
2131        let dtype = tensor.dtype;
2132        let mut shape = tensor.shape.clone();
2133        shape[dim] = 1;
2134        let client = tensor.client.clone();
2135        let out = client.tensor_uninitialized(shape.clone(), dtype);
2136        let out_indices = client.tensor_uninitialized(shape, B::IntElem::dtype());
2137
2138        let desc = ReduceDimWithIndicesOpIr {
2139            tensor: tensor.into_ir(),
2140            dim,
2141            out: out.to_ir_out(),
2142            out_indices: out_indices.to_ir_out(),
2143        };
2144        client.register(
2145            streams,
2146            OperationIr::NumericFloat(dtype, NumericOperationIr::MinDimWithIndices(desc.clone())),
2147            MinDimWithIndicesOps::<B>::new(desc),
2148        );
2149
2150        (out, out_indices)
2151    }
2152
2153    fn float_max_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
2154        unary_float_ops!(MaxAbsOps, B::float_max_abs, reduce);
2155
2156        let mut streams = OperationStreams::default();
2157        streams.tensor(&tensor);
2158        let dtype = tensor.dtype;
2159        let out = tensor.client.tensor_uninitialized(Shape::new([1]), dtype);
2160
2161        let desc = UnaryOpIr {
2162            input: tensor.into_ir(),
2163            out: out.to_ir_out(),
2164        };
2165        out.client.register(
2166            streams,
2167            OperationIr::NumericFloat(dtype, NumericOperationIr::MaxAbs(desc.clone())),
2168            MaxAbsOps::<B>::new(desc),
2169        );
2170
2171        out
2172    }
2173
2174    fn float_max_abs_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
2175        reduce_float_ops!(MaxAbsDimOps, B::float_max_abs_dim);
2176
2177        let mut streams = OperationStreams::default();
2178        streams.tensor(&tensor);
2179        let mut shape = tensor.shape.clone();
2180        let dtype = tensor.dtype;
2181        shape[dim] = 1;
2182        let out = tensor.client.tensor_uninitialized(shape, dtype);
2183
2184        let desc = ReduceDimOpIr {
2185            input: tensor.into_ir(),
2186            axis: dim,
2187            out: out.to_ir_out(),
2188        };
2189        out.client.register(
2190            streams,
2191            OperationIr::NumericFloat(dtype, NumericOperationIr::MaxAbsDim(desc.clone())),
2192            MaxAbsDimOps::<B>::new(desc),
2193        );
2194
2195        out
2196    }
2197
2198    fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
2199        binary_float_ops!(PowOps, B::float_powf);
2200        let mut streams = OperationStreams::default();
2201        streams.tensor(&lhs);
2202        streams.tensor(&rhs);
2203        let dtype = lhs.dtype;
2204
2205        let out = lhs
2206            .client
2207            .tensor_uninitialized(lhs.shape.broadcast(&rhs.shape).unwrap(), dtype);
2208
2209        let desc = BinaryOpIr {
2210            lhs: lhs.into_ir(),
2211            rhs: rhs.into_ir(),
2212            out: out.to_ir_out(),
2213        };
2214        out.client.register(
2215            streams,
2216            OperationIr::NumericFloat(dtype, NumericOperationIr::Powf(desc.clone())),
2217            PowOps::<B>::new(desc),
2218        );
2219
2220        out
2221    }
2222
2223    fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
2224        #[derive(new, Debug)]
2225        struct PermuteDimsOps<B: FusionBackend> {
2226            desc: PermuteOpIr,
2227            _b: PhantomData<B>,
2228        }
2229
2230        impl<B: FusionBackend> Operation<B::FusionRuntime> for PermuteDimsOps<B> {
2231            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2232                let input = handles.get_float_tensor::<B>(&self.desc.input);
2233                let output = B::float_permute(input, self.desc.axes.as_slice());
2234                handles.register_float_tensor::<B>(&self.desc.out.id, output);
2235            }
2236        }
2237
2238        let mut streams = OperationStreams::default();
2239        streams.tensor(&tensor);
2240
2241        // Change the shape of the tensor to match the new axes
2242        let shape = tensor.shape.clone().permute(axes).unwrap();
2243        let out = tensor.client.tensor_uninitialized(shape, tensor.dtype);
2244
2245        let desc = PermuteOpIr {
2246            input: tensor.into_ir(),
2247            axes: axes.to_vec(),
2248            out: out.to_ir_out(),
2249        };
2250
2251        out.client.register(
2252            streams,
2253            OperationIr::BaseInt(BaseOperationIr::Permute(desc.clone())),
2254            PermuteDimsOps::<B>::new(desc),
2255        );
2256
2257        out
2258    }
2259
2260    fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
2261        #[derive(new, Debug)]
2262        struct ExpandOps<B: FusionBackend> {
2263            desc: ExpandOpIr,
2264            _b: PhantomData<B>,
2265        }
2266
2267        impl<B: FusionBackend> Operation<B::FusionRuntime> for ExpandOps<B> {
2268            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2269                let input = handles.get_float_tensor::<B>(&self.desc.input);
2270                let output = B::float_expand(input, self.desc.shape.clone());
2271
2272                handles.register_float_tensor::<B>(&self.desc.out.id, output);
2273            }
2274        }
2275
2276        let mut streams = OperationStreams::default();
2277        streams.tensor(&tensor);
2278
2279        let out = tensor
2280            .client
2281            .tensor_uninitialized(shape.clone(), tensor.dtype);
2282
2283        let desc = ExpandOpIr {
2284            input: tensor.into_ir(),
2285            shape,
2286            out: out.to_ir_out(),
2287        };
2288
2289        out.client.register(
2290            streams,
2291            OperationIr::BaseFloat(BaseOperationIr::Expand(desc.clone())),
2292            ExpandOps::<B>::new(desc),
2293        );
2294
2295        out
2296    }
2297
2298    fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
2299        #[derive(new, Debug)]
2300        struct FlipOps<B: FusionBackend> {
2301            desc: FlipOpIr,
2302            _b: PhantomData<B>,
2303        }
2304
2305        impl<B: FusionBackend> Operation<B::FusionRuntime> for FlipOps<B> {
2306            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2307                let input = handles.get_float_tensor::<B>(&self.desc.input);
2308                let output = B::float_flip(input, &self.desc.axes);
2309                handles.register_float_tensor::<B>(&self.desc.out.id, output);
2310            }
2311        }
2312
2313        let mut streams = OperationStreams::default();
2314        streams.tensor(&tensor);
2315        let out = tensor
2316            .client
2317            .tensor_uninitialized(tensor.shape.clone(), tensor.dtype);
2318
2319        let desc = FlipOpIr {
2320            input: tensor.into_ir(),
2321            axes: axes.to_vec(),
2322            out: out.to_ir_out(),
2323        };
2324
2325        out.client.register(
2326            streams,
2327            OperationIr::BaseInt(BaseOperationIr::Flip(desc.clone())),
2328            FlipOps::<B>::new(desc),
2329        );
2330
2331        out
2332    }
2333
2334    fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
2335        unary_float_ops!(RoundOps, B::float_round);
2336
2337        let mut streams = OperationStreams::default();
2338        streams.tensor(&tensor);
2339        let dtype = tensor.dtype;
2340        let out = tensor
2341            .client
2342            .tensor_uninitialized(tensor.shape.clone(), dtype);
2343
2344        let desc = UnaryOpIr {
2345            input: tensor.into_ir(),
2346            out: out.to_ir_out(),
2347        };
2348        out.client.register(
2349            streams,
2350            OperationIr::Float(dtype, FloatOperationIr::Round(desc.clone())),
2351            RoundOps::<B>::new(desc),
2352        );
2353
2354        out
2355    }
2356
2357    fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
2358        unary_float_ops!(FloorOps, B::float_floor);
2359
2360        let mut streams = OperationStreams::default();
2361        streams.tensor(&tensor);
2362        let dtype = tensor.dtype;
2363        let out = tensor
2364            .client
2365            .tensor_uninitialized(tensor.shape.clone(), dtype);
2366
2367        let desc = UnaryOpIr {
2368            input: tensor.into_ir(),
2369            out: out.to_ir_out(),
2370        };
2371        out.client.register(
2372            streams,
2373            OperationIr::Float(dtype, FloatOperationIr::Floor(desc.clone())),
2374            FloorOps::<B>::new(desc),
2375        );
2376
2377        out
2378    }
2379
2380    fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
2381        unary_float_ops!(CeilOps, B::float_ceil);
2382
2383        let mut streams = OperationStreams::default();
2384        streams.tensor(&tensor);
2385        let dtype = tensor.dtype;
2386        let out = tensor
2387            .client
2388            .tensor_uninitialized(tensor.shape.clone(), dtype);
2389
2390        let desc = UnaryOpIr {
2391            input: tensor.into_ir(),
2392            out: out.to_ir_out(),
2393        };
2394        out.client.register(
2395            streams,
2396            OperationIr::Float(dtype, FloatOperationIr::Ceil(desc.clone())),
2397            CeilOps::<B>::new(desc),
2398        );
2399
2400        out
2401    }
2402
2403    fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
2404        unary_float_ops!(TruncOps, B::float_trunc);
2405
2406        let mut streams = OperationStreams::default();
2407        streams.tensor(&tensor);
2408        let dtype = tensor.dtype;
2409        let out = tensor
2410            .client
2411            .tensor_uninitialized(tensor.shape.clone(), dtype);
2412
2413        let desc = UnaryOpIr {
2414            input: tensor.into_ir(),
2415            out: out.to_ir_out(),
2416        };
2417        out.client.register(
2418            streams,
2419            OperationIr::Float(dtype, FloatOperationIr::Trunc(desc.clone())),
2420            TruncOps::<B>::new(desc),
2421        );
2422
2423        out
2424    }
2425
2426    fn float_cast(tensor: FloatTensor<Self>, dtype: burn_tensor::FloatDType) -> FloatTensor<Self> {
2427        #[derive(new, Debug)]
2428        struct CastOps<B: FusionBackend> {
2429            desc: UnaryOpIr,
2430            dtype: burn_tensor::FloatDType,
2431            _b: PhantomData<B>,
2432        }
2433
2434        impl<B: FusionBackend> Operation<B::FusionRuntime> for CastOps<B> {
2435            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2436                let input = handles.get_float_tensor::<B>(&self.desc.input);
2437                let output: B::FloatTensorPrimitive = B::float_cast(input, self.dtype);
2438                handles.register_float_tensor::<B>(&self.desc.out.id, output);
2439            }
2440        }
2441
2442        let mut streams = OperationStreams::default();
2443        streams.tensor(&tensor);
2444        let out = tensor
2445            .client
2446            .tensor_uninitialized(tensor.shape.clone(), dtype.into());
2447
2448        let desc = UnaryOpIr {
2449            input: tensor.into_ir(),
2450            out: out.to_ir_out(),
2451        };
2452        out.client.register(
2453            streams,
2454            OperationIr::BaseFloat(BaseOperationIr::Cast(desc.clone())),
2455            CastOps::<B>::new(desc, dtype),
2456        );
2457
2458        out
2459    }
2460
2461    fn float_unfold(
2462        tensor: FloatTensor<Self>,
2463        dim: usize,
2464        size: usize,
2465        step: usize,
2466    ) -> FloatTensor<Self> {
2467        #[derive(new, Debug)]
2468        struct UnfoldOps<B: FusionBackend> {
2469            desc: UnfoldOpIr,
2470            _b: PhantomData<B>,
2471        }
2472
2473        impl<B: FusionBackend> Operation<B::FusionRuntime> for UnfoldOps<B> {
2474            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2475                let input = handles.get_float_tensor::<B>(&self.desc.input);
2476                let output = B::float_unfold(input, self.desc.dim, self.desc.size, self.desc.step);
2477
2478                handles.register_float_tensor::<B>(&self.desc.out.id, output);
2479            }
2480        }
2481
2482        let mut streams = OperationStreams::default();
2483        streams.tensor(&tensor);
2484
2485        let shape = calculate_unfold_shape(tensor.shape(), dim, size, step);
2486
2487        let out = tensor
2488            .client
2489            .tensor_uninitialized(Shape::from(shape), tensor.dtype);
2490
2491        let desc = UnfoldOpIr {
2492            input: tensor.into_ir(),
2493            out: out.to_ir_out(),
2494            dim,
2495            size,
2496            step,
2497        };
2498
2499        out.client.register(
2500            streams,
2501            OperationIr::BaseFloat(BaseOperationIr::Unfold(desc.clone())),
2502            UnfoldOps::<B>::new(desc),
2503        );
2504
2505        out
2506    }
2507
2508    fn float_is_nan(tensor: FloatTensor<Self>) -> BoolTensor<Self> {
2509        #[derive(new, Debug)]
2510        struct IsNanOps<B: FusionBackend> {
2511            desc: UnaryOpIr,
2512            _b: PhantomData<B>,
2513        }
2514        impl<B: FusionBackend> Operation<B::FusionRuntime> for IsNanOps<B> {
2515            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2516                let input = handles.get_float_tensor::<B>(&self.desc.input);
2517                let output = B::float_is_nan(input);
2518                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
2519            }
2520        }
2521
2522        let mut streams = OperationStreams::default();
2523        streams.tensor(&tensor);
2524        let out = tensor
2525            .client
2526            .tensor_uninitialized(tensor.shape.clone(), B::BoolElem::dtype());
2527
2528        let dtype = tensor.dtype;
2529        let desc = UnaryOpIr {
2530            input: tensor.into_ir(),
2531            out: out.to_ir_out(),
2532        };
2533        out.client.register(
2534            streams,
2535            OperationIr::Float(dtype, FloatOperationIr::IsNan(desc.clone())),
2536            IsNanOps::<B>::new(desc),
2537        );
2538
2539        out
2540    }
2541
2542    fn float_is_inf(tensor: FloatTensor<Self>) -> BoolTensor<Self> {
2543        #[derive(new, Debug)]
2544        struct IsInfOps<B: FusionBackend> {
2545            desc: UnaryOpIr,
2546            _b: PhantomData<B>,
2547        }
2548        impl<B: FusionBackend> Operation<B::FusionRuntime> for IsInfOps<B> {
2549            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2550                let input = handles.get_float_tensor::<B>(&self.desc.input);
2551                let output = B::float_is_inf(input);
2552                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
2553            }
2554        }
2555
2556        let mut streams = OperationStreams::default();
2557        streams.tensor(&tensor);
2558        let out = tensor
2559            .client
2560            .tensor_uninitialized(tensor.shape.clone(), B::BoolElem::dtype());
2561
2562        let dtype = tensor.dtype;
2563        let desc = UnaryOpIr {
2564            input: tensor.into_ir(),
2565            out: out.to_ir_out(),
2566        };
2567        out.client.register(
2568            streams,
2569            OperationIr::Float(dtype, FloatOperationIr::IsInf(desc.clone())),
2570            IsInfOps::<B>::new(desc),
2571        );
2572
2573        out
2574    }
2575}