1use super::NoOp;
2use crate::{
3 Fusion, FusionBackend, binary_float_cmp_ops, binary_float_ops,
4 client::GlobalFusionClient,
5 get_client, reduce_float_ops, reduce_float2int_ops, scalar_float_cmp_ops, scalar_float_ops,
6 stream::{OperationStreams, execution::Operation},
7 unary_float_ops,
8};
9use burn_backend::{
10 BoolDType, Distribution, ExecutionError, FloatDType, IntDType, Scalar, Shape, Slice,
11 TensorData,
12 ops::{FloatTensorOps, GridSampleOptions},
13 tensor::{BoolTensor, Device, FloatTensor, IndexingUpdateOp, IntTensor},
14};
15use burn_ir::*;
16use std::marker::PhantomData;
17
18impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
19 #[cfg_attr(feature = "tracing", tracing::instrument(
20 level="trace",
21 skip(data),
22 fields(?data.shape, ?data.dtype)
23 ))]
24 fn float_from_data(data: TensorData, device: &Device<Self>) -> FloatTensor<Self> {
25 let client = get_client::<B>(device);
26 let dtype = data.dtype;
27 let tensor = B::float_from_data(data, device);
28 let shape = burn_backend::TensorMetadata::shape(&tensor);
29
30 let handle = B::float_tensor_handle(tensor);
31 let desc = InitOperationIr::create(shape, dtype, || client.register_tensor_handle(handle));
32
33 client
34 .register(
35 OperationStreams::default(),
36 OperationIr::Init(desc),
37 NoOp::<B>::new(),
38 )
39 .output()
40 }
41
42 fn float_random(
43 shape: Shape,
44 distribution: Distribution,
45 device: &Device<Self>,
46 dtype: FloatDType,
47 ) -> FloatTensor<Self> {
48 #[derive(new, Debug)]
49 struct RandomOps<B: FusionBackend> {
50 desc: RandomOpIr,
51 device: Device<B>,
52 }
53
54 impl<B: FusionBackend> Operation<B::FusionRuntime> for RandomOps<B> {
55 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
56 let output: B::FloatTensorPrimitive = B::float_random(
57 self.desc.out.shape.clone(),
58 self.desc.distribution,
59 &self.device,
60 self.desc.out.dtype.into(),
61 );
62 handles.register_float_tensor::<B>(&self.desc.out.id, output);
63 }
64 }
65
66 let dtype = dtype.into();
67 let client = get_client::<B>(device);
68 let desc = RandomOpIr::create(shape, dtype, distribution, || client.create_empty_handle());
69
70 client
71 .register(
72 OperationStreams::default(),
73 OperationIr::Float(dtype, FloatOperationIr::Random(desc.clone())),
74 RandomOps::<B>::new(desc, device.clone()),
75 )
76 .output()
77 }
78
79 fn float_zeros(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
80 #[derive(new, Debug)]
81 struct ZerosOps<B: FusionBackend> {
82 out: TensorIr,
83 device: Device<B>,
84 }
85
86 impl<B: FusionBackend> Operation<B::FusionRuntime> for ZerosOps<B> {
87 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
88 let shape = self.out.shape.clone();
89 let output = B::float_zeros(shape, &self.device, self.out.dtype.into());
90 handles.register_float_tensor::<B>(&self.out.id, output);
91 }
92 }
93
94 let client = get_client::<B>(device);
95 let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
96
97 client
98 .register(
99 OperationStreams::default(),
100 OperationIr::BaseFloat(BaseOperationIr::Zeros(desc.clone())),
101 ZerosOps::<B>::new(desc.out, device.clone()),
102 )
103 .output()
104 }
105
106 fn float_ones(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
107 #[derive(new, Debug)]
108 struct OnesOps<B: FusionBackend> {
109 out: TensorIr,
110 device: Device<B>,
111 }
112
113 impl<B: FusionBackend> Operation<B::FusionRuntime> for OnesOps<B> {
114 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
115 let shape = self.out.shape.clone();
116 let output = B::float_ones(shape, &self.device, self.out.dtype.into());
117 handles.register_float_tensor::<B>(&self.out.id, output);
118 }
119 }
120
121 let client = get_client::<B>(device);
122 let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
123
124 client
125 .register(
126 OperationStreams::default(),
127 OperationIr::BaseFloat(BaseOperationIr::Ones(desc.clone())),
128 OnesOps::<B>::new(desc.out, device.clone()),
129 )
130 .output()
131 }
132
133 fn float_full(
134 shape: Shape,
135 fill_value: Scalar,
136 device: &Device<Self>,
137 dtype: FloatDType,
138 ) -> FloatTensor<Self> {
139 #[derive(new, Debug)]
140 struct FullOps<B: FusionBackend> {
141 out: TensorIr,
142 elem: ScalarIr,
143 device: Device<B>,
144 }
145
146 impl<B: FusionBackend> Operation<B::FusionRuntime> for FullOps<B> {
147 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
148 let shape = self.out.shape.clone();
149 let dtype = self.out.dtype.into();
150 let output: B::FloatTensorPrimitive =
151 B::float_full(shape, self.elem.into(), &self.device, dtype);
152 handles.register_float_tensor::<B>(&self.out.id, output);
153 }
154 }
155
156 let dtype = dtype.into();
157 let client = get_client::<B>(device);
158 let value = fill_value.into();
159 let desc = FullOpIr::create(shape, dtype, value, || client.create_empty_handle());
160
161 client
162 .register(
163 OperationStreams::default(),
164 OperationIr::NumericFloat(dtype, NumericOperationIr::Full(desc.clone())),
165 FullOps::<B>::new(desc.out, desc.value, device.clone()),
166 )
167 .output()
168 }
169
170 #[cfg_attr(feature = "tracing", tracing::instrument(
171 level="trace",
172 skip(tensor),
173 fields(
174 from = ?tensor.client.device(),
175 shape = ?tensor.shape,
176 dtype = ?tensor.dtype
177 )
178 ))]
179 async fn float_into_data(tensor: FloatTensor<Self>) -> Result<TensorData, ExecutionError> {
180 tensor.into_data::<B>().await
181 }
182
183 fn float_device(tensor: &FloatTensor<Self>) -> Device<Self> {
184 tensor.client.device().clone()
185 }
186
187 #[cfg_attr(feature = "tracing", tracing::instrument(
188 level="trace",
189 skip(tensor),
190 fields(
191 from = ?tensor.client.device(),
192 shape = ?tensor.shape,
193 dtype = ?tensor.dtype,
194 )
195 ))]
196 fn float_to_device(tensor: FloatTensor<Self>, device_dst: &Device<Self>) -> FloatTensor<Self> {
197 let device_src: &B::Device = tensor.client.device();
198
199 if device_src == device_dst {
200 return tensor;
201 }
202
203 let id = tensor.stream;
204 let client_dst = get_client::<B>(device_dst);
205 let client_src = tensor.client.clone();
206
207 GlobalFusionClient::change_client_float::<B>(tensor.into_ir(), client_src, client_dst, id)
208 }
209
210 fn float_into_int(tensor: FloatTensor<Self>, dtype: IntDType) -> IntTensor<Self> {
211 #[derive(new, Debug)]
212 struct IntoIntOps<B: FusionBackend> {
213 desc: CastOpIr,
214 _b: PhantomData<B>,
215 }
216
217 impl<B: FusionBackend> Operation<B::FusionRuntime> for IntoIntOps<B> {
218 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
219 let input = handles.get_float_tensor::<B>(&self.desc.input);
220 let output = B::float_into_int(input, self.desc.out.dtype.into());
221
222 handles.register_int_tensor::<B>(&self.desc.out.id, output);
223 }
224 }
225
226 let streams = OperationStreams::with_inputs([&tensor]);
227
228 let client = tensor.client.clone();
229 let desc = CastOpIr::create(tensor.into_ir(), dtype.into(), || {
230 client.create_empty_handle()
231 });
232
233 client
234 .register(
235 streams,
236 OperationIr::Float(desc.input.dtype, FloatOperationIr::IntoInt(desc.clone())),
237 IntoIntOps::<B>::new(desc),
238 )
239 .output()
240 }
241
242 fn float_empty(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
243 #[derive(new, Debug)]
244 struct EmptyOps<B: FusionBackend> {
245 desc: TensorIr,
246 device: Device<B>,
247 }
248
249 impl<B: FusionBackend> Operation<B::FusionRuntime> for EmptyOps<B> {
250 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
251 let output = B::float_empty(
252 self.desc.shape.clone(),
253 &self.device,
254 self.desc.dtype.into(),
255 );
256 handles.register_float_tensor::<B>(&self.desc.id, output);
257 }
258 }
259
260 let client = get_client::<B>(device);
261 let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
262
263 client
264 .register(
265 OperationStreams::default(),
266 OperationIr::BaseFloat(BaseOperationIr::Empty(desc.clone())),
267 EmptyOps::<B>::new(desc.out, device.clone()),
268 )
269 .output()
270 }
271
272 fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
273 binary_float_ops!(AddOps, B::float_add);
274
275 let streams = OperationStreams::with_inputs([&lhs, &rhs]);
276
277 let client = lhs.client.clone();
278 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
279 client.create_empty_handle()
280 });
281
282 client
283 .register(
284 streams,
285 OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Add(desc.clone())),
286 AddOps::<B>::new(desc),
287 )
288 .output()
289 }
290
291 fn float_add_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
292 scalar_float_ops!(AddOps, B::float_add_scalar);
293
294 let streams = OperationStreams::with_inputs([&lhs]);
295
296 let client = lhs.client.clone();
297 let rhs = rhs.into();
298 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
299
300 client
301 .register(
302 streams,
303 OperationIr::NumericFloat(
304 desc.out.dtype,
305 NumericOperationIr::AddScalar(desc.clone()),
306 ),
307 AddOps::<B>::new(desc),
308 )
309 .output()
310 }
311
312 fn float_clamp(tensor: FloatTensor<Self>, min: Scalar, max: Scalar) -> FloatTensor<Self> {
313 #[derive(new, Debug)]
314 struct ClampOps<B: FusionBackend> {
315 desc: ClampOpIr,
316 _b: PhantomData<B>,
317 }
318
319 impl<B: FusionBackend> Operation<B::FusionRuntime> for ClampOps<B> {
320 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
321 let input = handles.get_float_tensor::<B>(&self.desc.tensor);
322 let output = B::float_clamp(input, self.desc.min.into(), self.desc.max.into());
323
324 handles.register_float_tensor::<B>(&self.desc.out.id, output);
325 }
326 }
327
328 let streams = OperationStreams::with_inputs([&tensor]);
329
330 let client = tensor.client.clone();
331 let min = min.into();
332 let max = max.into();
333 let desc = ClampOpIr::create(tensor.into_ir(), min, max, || client.create_empty_handle());
334
335 client
336 .register(
337 streams,
338 OperationIr::NumericFloat(
339 desc.tensor.dtype,
340 NumericOperationIr::Clamp(desc.clone()),
341 ),
342 ClampOps::<B>::new(desc),
343 )
344 .output()
345 }
346
347 fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
348 binary_float_ops!(SubOps, B::float_sub);
349
350 let streams = OperationStreams::with_inputs([&lhs, &rhs]);
351
352 let client = lhs.client.clone();
353 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
354 client.create_empty_handle()
355 });
356
357 client
358 .register(
359 streams,
360 OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Sub(desc.clone())),
361 SubOps::<B>::new(desc),
362 )
363 .output()
364 }
365
366 fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
367 scalar_float_ops!(SubOps, B::float_sub_scalar);
368
369 let streams = OperationStreams::with_inputs([&lhs]);
370
371 let client = lhs.client.clone();
372 let rhs = rhs.into();
373 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
374
375 client
376 .register(
377 streams,
378 OperationIr::NumericFloat(
379 desc.out.dtype,
380 NumericOperationIr::SubScalar(desc.clone()),
381 ),
382 SubOps::<B>::new(desc),
383 )
384 .output()
385 }
386
387 fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
388 binary_float_ops!(MulOps, B::float_mul);
389
390 let streams = OperationStreams::with_inputs([&lhs, &rhs]);
391
392 let client = lhs.client.clone();
393 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
394 client.create_empty_handle()
395 });
396
397 client
398 .register(
399 streams,
400 OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Mul(desc.clone())),
401 MulOps::<B>::new(desc),
402 )
403 .output()
404 }
405
406 fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
407 scalar_float_ops!(MulOps, B::float_mul_scalar);
408
409 let streams = OperationStreams::with_inputs([&lhs]);
410
411 let client = lhs.client.clone();
412 let rhs = rhs.into();
413 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
414
415 client
416 .register(
417 streams,
418 OperationIr::NumericFloat(
419 desc.out.dtype,
420 NumericOperationIr::MulScalar(desc.clone()),
421 ),
422 MulOps::<B>::new(desc),
423 )
424 .output()
425 }
426
427 fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
428 binary_float_ops!(DivOps, B::float_div);
429
430 let streams = OperationStreams::with_inputs([&lhs, &rhs]);
431
432 let client = lhs.client.clone();
433 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
434 client.create_empty_handle()
435 });
436
437 client
438 .register(
439 streams,
440 OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Div(desc.clone())),
441 DivOps::<B>::new(desc),
442 )
443 .output()
444 }
445
446 fn float_div_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
447 scalar_float_ops!(DivOps, B::float_div_scalar);
448
449 let streams = OperationStreams::with_inputs([&lhs]);
450
451 let client = lhs.client.clone();
452 let rhs = rhs.into();
453 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
454
455 client
456 .register(
457 streams,
458 OperationIr::NumericFloat(
459 desc.out.dtype,
460 NumericOperationIr::DivScalar(desc.clone()),
461 ),
462 DivOps::<B>::new(desc),
463 )
464 .output()
465 }
466
467 fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
468 binary_float_ops!(ModOps, B::float_remainder);
469
470 let streams = OperationStreams::with_inputs([&lhs, &rhs]);
471
472 let client = lhs.client.clone();
473 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
474 client.create_empty_handle()
475 });
476
477 client
478 .register(
479 streams,
480 OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Rem(desc.clone())),
481 ModOps::<B>::new(desc),
482 )
483 .output()
484 }
485
486 fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
487 scalar_float_ops!(ModOps, B::float_remainder_scalar);
488
489 let streams = OperationStreams::with_inputs([&lhs]);
490
491 let client = lhs.client.clone();
492 let rhs = rhs.into();
493 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
494
495 client
496 .register(
497 streams,
498 OperationIr::NumericFloat(
499 desc.out.dtype,
500 NumericOperationIr::RemScalar(desc.clone()),
501 ),
502 ModOps::<B>::new(desc),
503 )
504 .output()
505 }
506
507 fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
508 binary_float_ops!(MatmulOps, B::float_matmul);
509
510 let streams = OperationStreams::with_inputs([&lhs, &rhs]);
511
512 let client = lhs.client.clone();
513 let desc = MatmulOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
514 client.create_empty_handle()
515 });
516
517 client
518 .register(
519 streams,
520 OperationIr::Float(desc.out.dtype, FloatOperationIr::Matmul(desc.clone())),
521 MatmulOps::<B>::new(desc.into()),
522 )
523 .output()
524 }
525
526 fn float_cross(
527 lhs: FloatTensor<Self>,
528 rhs: FloatTensor<Self>,
529 dim: usize,
530 ) -> FloatTensor<Self> {
531 #[derive(new, Debug)]
532 struct CrossOps<B: FusionBackend> {
533 desc: CrossOpIr,
534 _b: PhantomData<B>,
535 }
536
537 impl<B: FusionBackend> Operation<B::FusionRuntime> for CrossOps<B> {
538 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
539 let lhs = handles.get_float_tensor::<B>(&self.desc.lhs);
540 let rhs = handles.get_float_tensor::<B>(&self.desc.rhs);
541 let output = B::float_cross(lhs, rhs, self.desc.dim);
542 handles.register_float_tensor::<B>(&self.desc.out.id, output);
543 }
544 }
545
546 let streams = OperationStreams::with_inputs([&lhs, &rhs]);
547
548 let client = lhs.client.clone();
549 let desc = CrossOpIr::create(lhs.into_ir(), rhs.into_ir(), dim, || {
550 client.create_empty_handle()
551 });
552
553 client
554 .register(
555 streams,
556 OperationIr::Float(desc.out.dtype, FloatOperationIr::Cross(desc.clone())),
557 CrossOps::<B>::new(desc),
558 )
559 .output()
560 }
561
562 fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
563 #[derive(new, Debug)]
564 struct SwapDimsOps<B: FusionBackend> {
565 desc: SwapDimsOpIr,
566 _b: PhantomData<B>,
567 }
568
569 impl<B: FusionBackend> Operation<B::FusionRuntime> for SwapDimsOps<B> {
570 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
571 let input = handles.get_float_tensor::<B>(&self.desc.input);
572 let output = B::float_swap_dims(input, self.desc.dim1, self.desc.dim2);
573 handles.register_float_tensor::<B>(&self.desc.out.id, output);
574 }
575 }
576
577 let streams = OperationStreams::with_inputs([&tensor]);
578
579 let client = tensor.client.clone();
580 let desc = SwapDimsOpIr::create(tensor.into_ir(), dim1, dim2, || {
581 client.create_empty_handle()
582 });
583
584 client
585 .register(
586 streams,
587 OperationIr::BaseFloat(BaseOperationIr::SwapDims(desc.clone())),
588 SwapDimsOps::<B>::new(desc),
589 )
590 .output()
591 }
592
593 fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
594 if tensor.shape == shape {
595 return tensor;
596 }
597
598 #[derive(new, Debug)]
599 struct ReshapeDimsOps<B: FusionBackend> {
600 desc: ShapeOpIr,
601 _b: PhantomData<B>,
602 }
603
604 impl<B: FusionBackend> Operation<B::FusionRuntime> for ReshapeDimsOps<B> {
605 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
606 let input = handles.get_float_tensor::<B>(&self.desc.input);
607 let output = B::float_reshape(input, self.desc.out.shape.clone());
608 handles.register_float_tensor::<B>(&self.desc.out.id, output);
609 }
610 }
611
612 let streams = OperationStreams::with_inputs([&tensor]);
613
614 let client = tensor.client.clone();
615 let desc = ShapeOpIr::reshape(tensor.into_ir(), shape, || client.create_empty_handle());
616
617 client
618 .register(
619 streams,
620 OperationIr::BaseFloat(BaseOperationIr::Reshape(desc.clone())),
621 ReshapeDimsOps::<B>::new(desc),
622 )
623 .output()
624 }
625
626 fn float_gather(
627 dim: usize,
628 tensor: FloatTensor<Self>,
629 indices: IntTensor<Self>,
630 ) -> FloatTensor<Self> {
631 #[derive(new, Debug)]
632 struct GatherOps<B: FusionBackend> {
633 desc: GatherOpIr,
634 _b: PhantomData<B>,
635 }
636
637 impl<B: FusionBackend> Operation<B::FusionRuntime> for GatherOps<B> {
638 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
639 let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
640 let indices = handles.get_int_tensor::<B>(&self.desc.indices);
641
642 let output = B::float_gather(self.desc.dim, tensor, indices);
643 handles.register_float_tensor::<B>(&self.desc.out.id, output);
644 }
645 }
646
647 let streams = OperationStreams::with_inputs([&tensor, &indices]);
648
649 let client = tensor.client.clone();
650 let desc = GatherOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
651 client.create_empty_handle()
652 });
653
654 client
655 .register(
656 streams,
657 OperationIr::BaseFloat(BaseOperationIr::Gather(desc.clone())),
658 GatherOps::<B>::new(desc),
659 )
660 .output()
661 }
662
663 fn float_scatter_add(
664 dim: usize,
665 tensor: FloatTensor<Self>,
666 indices: IntTensor<Self>,
667 value: FloatTensor<Self>,
668 ) -> FloatTensor<Self> {
669 #[derive(new, Debug)]
670 struct ScatterOps<B: FusionBackend> {
671 desc: ScatterOpIr,
672 _b: PhantomData<B>,
673 }
674
675 impl<B: FusionBackend> Operation<B::FusionRuntime> for ScatterOps<B> {
676 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
677 let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
678 let indices = handles.get_int_tensor::<B>(&self.desc.indices);
679 let value = handles.get_float_tensor::<B>(&self.desc.value);
680
681 let output = B::float_scatter_add(self.desc.dim, tensor, indices, value);
682
683 handles.register_float_tensor::<B>(&self.desc.out.id, output);
684 }
685 }
686
687 let streams = OperationStreams::with_inputs([&tensor, &indices, &value]);
688
689 let client = tensor.client.clone();
690 let desc = ScatterOpIr::create(
691 tensor.into_ir(),
692 dim,
693 indices.into_ir(),
694 value.into_ir(),
695 IndexingUpdateOp::Add,
696 || client.create_empty_handle(),
697 );
698
699 client
700 .register(
701 streams,
702 OperationIr::BaseFloat(BaseOperationIr::Scatter(desc.clone())),
703 ScatterOps::<B>::new(desc),
704 )
705 .output()
706 }
707
708 fn float_scatter_nd(
709 data: FloatTensor<Self>,
710 indices: IntTensor<Self>,
711 values: FloatTensor<Self>,
712 reduction: IndexingUpdateOp,
713 ) -> FloatTensor<Self> {
714 #[derive(new, Debug)]
715 struct ScatterNdOps<B: FusionBackend> {
716 desc: ScatterNdOpIr,
717 _b: PhantomData<B>,
718 }
719
720 impl<B: FusionBackend> Operation<B::FusionRuntime> for ScatterNdOps<B> {
721 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
722 let data = handles.get_float_tensor::<B>(&self.desc.data);
723 let indices = handles.get_int_tensor::<B>(&self.desc.indices);
724 let values = handles.get_float_tensor::<B>(&self.desc.values);
725
726 let output = B::float_scatter_nd(data, indices, values, self.desc.reduction);
727
728 handles.register_float_tensor::<B>(&self.desc.out.id, output);
729 }
730 }
731
732 let streams = OperationStreams::with_inputs([&data, &indices, &values]);
733
734 let client = data.client.clone();
735 let desc = ScatterNdOpIr::create(
736 data.into_ir(),
737 indices.into_ir(),
738 values.into_ir(),
739 reduction,
740 || client.create_empty_handle(),
741 );
742
743 client
744 .register(
745 streams,
746 OperationIr::BaseFloat(BaseOperationIr::ScatterNd(desc.clone())),
747 ScatterNdOps::<B>::new(desc),
748 )
749 .output()
750 }
751
752 fn float_gather_nd(data: FloatTensor<Self>, indices: IntTensor<Self>) -> FloatTensor<Self> {
753 #[derive(new, Debug)]
754 struct GatherNdOps<B: FusionBackend> {
755 desc: GatherNdOpIr,
756 _b: PhantomData<B>,
757 }
758
759 impl<B: FusionBackend> Operation<B::FusionRuntime> for GatherNdOps<B> {
760 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
761 let data = handles.get_float_tensor::<B>(&self.desc.data);
762 let indices = handles.get_int_tensor::<B>(&self.desc.indices);
763
764 let output = B::float_gather_nd(data, indices);
765 handles.register_float_tensor::<B>(&self.desc.out.id, output);
766 }
767 }
768
769 let streams = OperationStreams::with_inputs([&data, &indices]);
770
771 let client = data.client.clone();
772 let desc = GatherNdOpIr::create(data.into_ir(), indices.into_ir(), || {
773 client.create_empty_handle()
774 });
775
776 client
777 .register(
778 streams,
779 OperationIr::BaseFloat(BaseOperationIr::GatherNd(desc.clone())),
780 GatherNdOps::<B>::new(desc),
781 )
782 .output()
783 }
784
785 fn float_select(
786 tensor: FloatTensor<Self>,
787 dim: usize,
788 indices: IntTensor<Self>,
789 ) -> FloatTensor<Self> {
790 #[derive(new, Debug)]
791 struct SelectOps<B: FusionBackend> {
792 desc: SelectOpIr,
793 _b: PhantomData<B>,
794 }
795
796 impl<B: FusionBackend> Operation<B::FusionRuntime> for SelectOps<B> {
797 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
798 let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
799 let indices = handles.get_int_tensor::<B>(&self.desc.indices);
800
801 let output = B::float_select(tensor, self.desc.dim, indices);
802
803 handles.register_float_tensor::<B>(&self.desc.out.id, output);
804 }
805 }
806
807 let streams = OperationStreams::with_inputs([&tensor, &indices]);
808
809 let client = tensor.client.clone();
810 let desc = SelectOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
811 client.create_empty_handle()
812 });
813
814 client
815 .register(
816 streams,
817 OperationIr::BaseFloat(BaseOperationIr::Select(desc.clone())),
818 SelectOps::<B>::new(desc),
819 )
820 .output()
821 }
822
823 fn float_select_add(
824 tensor: FloatTensor<Self>,
825 dim: usize,
826 indices: IntTensor<Self>,
827 value: FloatTensor<Self>,
828 ) -> FloatTensor<Self> {
829 #[derive(new, Debug)]
830 struct SelectAssignOps<B: FusionBackend> {
831 desc: SelectAssignOpIr,
832 _b: PhantomData<B>,
833 }
834
835 impl<B: FusionBackend> Operation<B::FusionRuntime> for SelectAssignOps<B> {
836 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
837 let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
838 let indices = handles.get_int_tensor::<B>(&self.desc.indices);
839 let value = handles.get_float_tensor::<B>(&self.desc.value);
840
841 let output = B::float_select_add(tensor, self.desc.dim, indices, value);
842
843 handles.register_float_tensor::<B>(&self.desc.out.id, output);
844 }
845 }
846
847 let streams = OperationStreams::with_inputs([&tensor, &indices, &value]);
848
849 let client = tensor.client.clone();
850 let desc = SelectAssignOpIr::create(
851 tensor.into_ir(),
852 dim,
853 indices.into_ir(),
854 value.into_ir(),
855 IndexingUpdateOp::Add,
856 || client.create_empty_handle(),
857 );
858
859 client
860 .register(
861 streams,
862 OperationIr::BaseFloat(BaseOperationIr::SelectAssign(desc.clone())),
863 SelectAssignOps::<B>::new(desc),
864 )
865 .output()
866 }
867
868 fn float_slice(tensor: FloatTensor<Self>, slices: &[Slice]) -> FloatTensor<Self> {
869 #[derive(new, Debug)]
870 struct SliceOps<B: FusionBackend> {
871 desc: SliceOpIr,
872 _b: PhantomData<B>,
873 }
874
875 impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceOps<B> {
876 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
877 let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
878
879 let output = B::float_slice(tensor, self.desc.ranges.as_slice());
880
881 handles.register_float_tensor::<B>(&self.desc.out.id, output);
882 }
883 }
884
885 let streams = OperationStreams::with_inputs([&tensor]);
886
887 let client = tensor.client.clone();
888 let desc = SliceOpIr::create(tensor.into_ir(), slices.into(), || {
889 client.create_empty_handle()
890 });
891
892 client
893 .register(
894 streams,
895 OperationIr::BaseFloat(BaseOperationIr::Slice(desc.clone())),
896 SliceOps::<B>::new(desc),
897 )
898 .output()
899 }
900
901 fn float_slice_assign(
902 tensor: FloatTensor<Self>,
903 slices: &[burn_backend::Slice],
904 value: FloatTensor<Self>,
905 ) -> FloatTensor<Self> {
906 #[derive(new, Debug)]
907 struct SliceAssignOps<B: FusionBackend> {
908 desc: SliceAssignOpIr,
909 _b: PhantomData<B>,
910 }
911
912 impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceAssignOps<B> {
913 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
914 let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
915 let value = handles.get_float_tensor::<B>(&self.desc.value);
916
917 let output = B::float_slice_assign(tensor, self.desc.ranges.as_slice(), value);
918
919 handles.register_float_tensor::<B>(&self.desc.out.id, output);
920 }
921 }
922
923 let streams = OperationStreams::with_inputs([&tensor, &value]);
924
925 let client = tensor.client.clone();
926 let desc =
927 SliceAssignOpIr::create(tensor.into_ir(), slices.into(), value.into_ir(), || {
928 client.create_empty_handle()
929 });
930
931 client
932 .register(
933 streams,
934 OperationIr::BaseFloat(BaseOperationIr::SliceAssign(desc.clone())),
935 SliceAssignOps::<B>::new(desc),
936 )
937 .output()
938 }
939
940 fn float_mask_where(
941 tensor: FloatTensor<Self>,
942 mask: BoolTensor<Self>,
943 value: FloatTensor<Self>,
944 ) -> FloatTensor<Self> {
945 #[derive(new, Debug)]
946 struct MaskWhereOps<B: FusionBackend> {
947 desc: MaskWhereOpIr,
948 _b: PhantomData<B>,
949 }
950
951 impl<B: FusionBackend> Operation<B::FusionRuntime> for MaskWhereOps<B> {
952 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
953 let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
954 let value = handles.get_float_tensor::<B>(&self.desc.value);
955 let mask = handles.get_bool_tensor::<B>(&self.desc.mask);
956
957 let output = B::float_mask_where(tensor, mask, value);
958
959 handles.register_float_tensor::<B>(&self.desc.out.id, output);
960 }
961 }
962
963 let streams = OperationStreams::with_inputs([&tensor, &mask, &value]);
964
965 let client = tensor.client.clone();
966 let desc = MaskWhereOpIr::create(tensor.into_ir(), mask.into_ir(), value.into_ir(), || {
967 client.create_empty_handle()
968 });
969
970 client
971 .register(
972 streams,
973 OperationIr::BaseFloat(BaseOperationIr::MaskWhere(desc.clone())),
974 MaskWhereOps::<B>::new(desc),
975 )
976 .output()
977 }
978
979 fn float_mask_fill(
980 tensor: FloatTensor<Self>,
981 mask: BoolTensor<Self>,
982 value: Scalar,
983 ) -> FloatTensor<Self> {
984 #[derive(new, Debug)]
985 struct MaskFillOps<B: FusionBackend> {
986 desc: MaskFillOpIr,
987 _b: PhantomData<B>,
988 }
989
990 impl<B: FusionBackend> Operation<B::FusionRuntime> for MaskFillOps<B> {
991 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
992 let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
993 let mask = handles.get_bool_tensor::<B>(&self.desc.mask);
994
995 let output = B::float_mask_fill(tensor, mask, self.desc.value.into());
996
997 handles.register_float_tensor::<B>(&self.desc.out.id, output);
998 }
999 }
1000
1001 let streams = OperationStreams::with_inputs([&tensor, &mask]);
1002
1003 let client = tensor.client.clone();
1004 let value = value.into();
1005 let desc = MaskFillOpIr::create(tensor.into_ir(), mask.into_ir(), value, || {
1006 client.create_empty_handle()
1007 });
1008
1009 client
1010 .register(
1011 streams,
1012 OperationIr::BaseFloat(BaseOperationIr::MaskFill(desc.clone())),
1013 MaskFillOps::<B>::new(desc),
1014 )
1015 .output()
1016 }
1017
1018 fn float_equal(
1019 lhs: FloatTensor<Self>,
1020 rhs: FloatTensor<Self>,
1021 out_dtype: BoolDType,
1022 ) -> BoolTensor<Self> {
1023 binary_float_cmp_ops!(EqualOps, B::float_equal);
1024
1025 let streams = OperationStreams::with_inputs([&lhs, &rhs]);
1026
1027 let client = lhs.client.clone();
1028 let desc =
1029 BinaryOpIr::create_comparison(lhs.into_ir(), rhs.into_ir(), out_dtype.into(), || {
1030 client.create_empty_handle()
1031 });
1032
1033 client
1034 .register(
1035 streams,
1036 OperationIr::BaseFloat(BaseOperationIr::Equal(desc.clone())),
1037 EqualOps::<B>::new(desc),
1038 )
1039 .output()
1040 }
1041
1042 fn float_equal_elem(
1043 lhs: FloatTensor<Self>,
1044 rhs: Scalar,
1045 out_dtype: BoolDType,
1046 ) -> BoolTensor<Self> {
1047 scalar_float_cmp_ops!(EqualElemOps, B::float_equal_elem);
1048
1049 let streams = OperationStreams::with_inputs([&lhs]);
1050
1051 let client = lhs.client.clone();
1052 let rhs = rhs.into();
1053 let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, out_dtype.into(), || {
1054 client.create_empty_handle()
1055 });
1056
1057 client
1058 .register(
1059 streams,
1060 OperationIr::BaseFloat(BaseOperationIr::EqualElem(desc.clone())),
1061 EqualElemOps::<B>::new(desc),
1062 )
1063 .output()
1064 }
1065
1066 fn float_greater(
1067 lhs: FloatTensor<Self>,
1068 rhs: FloatTensor<Self>,
1069 out_dtype: BoolDType,
1070 ) -> BoolTensor<Self> {
1071 binary_float_cmp_ops!(GreaterOps, B::float_greater);
1072
1073 let streams = OperationStreams::with_inputs([&lhs, &rhs]);
1074
1075 let client = lhs.client.clone();
1076 let desc =
1077 BinaryOpIr::create_comparison(lhs.into_ir(), rhs.into_ir(), out_dtype.into(), || {
1078 client.create_empty_handle()
1079 });
1080
1081 client
1082 .register(
1083 streams,
1084 OperationIr::NumericFloat(
1085 desc.lhs.dtype,
1086 NumericOperationIr::Greater(desc.clone()),
1087 ),
1088 GreaterOps::<B>::new(desc),
1089 )
1090 .output()
1091 }
1092
1093 fn float_greater_elem(
1094 lhs: FloatTensor<Self>,
1095 rhs: Scalar,
1096 out_dtype: BoolDType,
1097 ) -> BoolTensor<Self> {
1098 scalar_float_cmp_ops!(GreaterElemOps, B::float_greater_elem);
1099
1100 let streams = OperationStreams::with_inputs([&lhs]);
1101
1102 let client = lhs.client.clone();
1103 let rhs = rhs.into();
1104 let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, out_dtype.into(), || {
1105 client.create_empty_handle()
1106 });
1107
1108 client
1109 .register(
1110 streams,
1111 OperationIr::NumericFloat(
1112 desc.lhs.dtype,
1113 NumericOperationIr::GreaterElem(desc.clone()),
1114 ),
1115 GreaterElemOps::<B>::new(desc),
1116 )
1117 .output()
1118 }
1119
1120 fn float_greater_equal(
1121 lhs: FloatTensor<Self>,
1122 rhs: FloatTensor<Self>,
1123 out_dtype: BoolDType,
1124 ) -> BoolTensor<Self> {
1125 binary_float_cmp_ops!(GreaterEqualOps, B::float_greater_equal);
1126
1127 let streams = OperationStreams::with_inputs([&lhs, &rhs]);
1128
1129 let client = lhs.client.clone();
1130 let desc =
1131 BinaryOpIr::create_comparison(lhs.into_ir(), rhs.into_ir(), out_dtype.into(), || {
1132 client.create_empty_handle()
1133 });
1134
1135 client
1136 .register(
1137 streams,
1138 OperationIr::NumericFloat(
1139 desc.lhs.dtype,
1140 NumericOperationIr::GreaterEqual(desc.clone()),
1141 ),
1142 GreaterEqualOps::<B>::new(desc),
1143 )
1144 .output()
1145 }
1146
1147 fn float_greater_equal_elem(
1148 lhs: FloatTensor<Self>,
1149 rhs: Scalar,
1150 out_dtype: BoolDType,
1151 ) -> BoolTensor<Self> {
1152 scalar_float_cmp_ops!(GreaterEqualElemOps, B::float_greater_equal_elem);
1153
1154 let streams = OperationStreams::with_inputs([&lhs]);
1155
1156 let client = lhs.client.clone();
1157 let rhs = rhs.into();
1158 let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, out_dtype.into(), || {
1159 client.create_empty_handle()
1160 });
1161
1162 client
1163 .register(
1164 streams,
1165 OperationIr::NumericFloat(
1166 desc.lhs.dtype,
1167 NumericOperationIr::GreaterEqualElem(desc.clone()),
1168 ),
1169 GreaterEqualElemOps::<B>::new(desc),
1170 )
1171 .output()
1172 }
1173
1174 fn float_lower(
1175 lhs: FloatTensor<Self>,
1176 rhs: FloatTensor<Self>,
1177 out_dtype: BoolDType,
1178 ) -> BoolTensor<Self> {
1179 binary_float_cmp_ops!(LowerOps, B::float_lower);
1180
1181 let streams = OperationStreams::with_inputs([&lhs, &rhs]);
1182
1183 let client = lhs.client.clone();
1184 let desc =
1185 BinaryOpIr::create_comparison(lhs.into_ir(), rhs.into_ir(), out_dtype.into(), || {
1186 client.create_empty_handle()
1187 });
1188
1189 client
1190 .register(
1191 streams,
1192 OperationIr::NumericFloat(desc.lhs.dtype, NumericOperationIr::Lower(desc.clone())),
1193 LowerOps::<B>::new(desc),
1194 )
1195 .output()
1196 }
1197
1198 fn float_lower_elem(
1199 lhs: FloatTensor<Self>,
1200 rhs: Scalar,
1201 out_dtype: BoolDType,
1202 ) -> BoolTensor<Self> {
1203 scalar_float_cmp_ops!(LowerElemOps, B::float_lower_elem);
1204
1205 let streams = OperationStreams::with_inputs([&lhs]);
1206
1207 let client = lhs.client.clone();
1208 let rhs = rhs.into();
1209 let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, out_dtype.into(), || {
1210 client.create_empty_handle()
1211 });
1212
1213 client
1214 .register(
1215 streams,
1216 OperationIr::NumericFloat(
1217 desc.lhs.dtype,
1218 NumericOperationIr::LowerElem(desc.clone()),
1219 ),
1220 LowerElemOps::<B>::new(desc),
1221 )
1222 .output()
1223 }
1224
1225 fn float_lower_equal(
1226 lhs: FloatTensor<Self>,
1227 rhs: FloatTensor<Self>,
1228 out_dtype: BoolDType,
1229 ) -> BoolTensor<Self> {
1230 binary_float_cmp_ops!(LowerEqualOps, B::float_lower_equal);
1231
1232 let streams = OperationStreams::with_inputs([&lhs, &rhs]);
1233
1234 let client = lhs.client.clone();
1235 let desc =
1236 BinaryOpIr::create_comparison(lhs.into_ir(), rhs.into_ir(), out_dtype.into(), || {
1237 client.create_empty_handle()
1238 });
1239
1240 client
1241 .register(
1242 streams,
1243 OperationIr::NumericFloat(
1244 desc.lhs.dtype,
1245 NumericOperationIr::LowerEqual(desc.clone()),
1246 ),
1247 LowerEqualOps::<B>::new(desc),
1248 )
1249 .output()
1250 }
1251
1252 fn float_lower_equal_elem(
1253 lhs: FloatTensor<Self>,
1254 rhs: Scalar,
1255 out_dtype: BoolDType,
1256 ) -> BoolTensor<Self> {
1257 scalar_float_cmp_ops!(LowerEqualElemOps, B::float_lower_equal_elem);
1258
1259 let streams = OperationStreams::with_inputs([&lhs]);
1260
1261 let client = lhs.client.clone();
1262 let rhs = rhs.into();
1263 let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, out_dtype.into(), || {
1264 client.create_empty_handle()
1265 });
1266
1267 client
1268 .register(
1269 streams,
1270 OperationIr::NumericFloat(
1271 desc.lhs.dtype,
1272 NumericOperationIr::LowerEqualElem(desc.clone()),
1273 ),
1274 LowerEqualElemOps::<B>::new(desc),
1275 )
1276 .output()
1277 }
1278
1279 fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1280 unary_float_ops!(SumOps, B::float_sum, reduce);
1281
1282 let streams = OperationStreams::with_inputs([&tensor]);
1283
1284 let client = tensor.client.clone();
1285 let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1286
1287 client
1288 .register(
1289 streams,
1290 OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Sum(desc.clone())),
1291 SumOps::<B>::new(desc.into()),
1292 )
1293 .output()
1294 }
1295
1296 fn float_sum_dim(tensor: FloatTensor<Self>, axis: usize) -> FloatTensor<Self> {
1297 reduce_float_ops!(SumDimOps, |tensor, axis, _| B::float_sum_dim(tensor, axis));
1298
1299 let streams = OperationStreams::with_inputs([&tensor]);
1300
1301 let client = tensor.client.clone();
1302 let desc =
1303 ReduceDimOpIr::create(tensor.into_ir(), axis, 1, || client.create_empty_handle());
1304
1305 client
1306 .register(
1307 streams,
1308 OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::SumDim(desc.clone())),
1309 SumDimOps::<B>::new(desc),
1310 )
1311 .output()
1312 }
1313
1314 fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1315 unary_float_ops!(ProdOps, B::float_prod, reduce);
1316
1317 let streams = OperationStreams::with_inputs([&tensor]);
1318
1319 let client = tensor.client.clone();
1320 let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1321
1322 client
1323 .register(
1324 streams,
1325 OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Prod(desc.clone())),
1326 ProdOps::<B>::new(desc.into()),
1327 )
1328 .output()
1329 }
1330
1331 fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
1332 reduce_float_ops!(ProdDimOps, |tensor, axis, _| B::float_prod_dim(
1333 tensor, axis
1334 ));
1335
1336 let streams = OperationStreams::with_inputs([&tensor]);
1337
1338 let client = tensor.client.clone();
1339 let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
1340
1341 client
1342 .register(
1343 streams,
1344 OperationIr::NumericFloat(
1345 desc.out.dtype,
1346 NumericOperationIr::ProdDim(desc.clone()),
1347 ),
1348 ProdDimOps::<B>::new(desc),
1349 )
1350 .output()
1351 }
1352
1353 fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1354 unary_float_ops!(MeanOps, B::float_mean, reduce);
1355
1356 let streams = OperationStreams::with_inputs([&tensor]);
1357
1358 let client = tensor.client.clone();
1359 let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1360
1361 client
1362 .register(
1363 streams,
1364 OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Mean(desc.clone())),
1365 MeanOps::<B>::new(desc.into()),
1366 )
1367 .output()
1368 }
1369
1370 fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
1371 reduce_float_ops!(MeanDimOps, |tensor, axis, _| B::float_mean_dim(
1372 tensor, axis
1373 ));
1374
1375 let streams = OperationStreams::with_inputs([&tensor]);
1376
1377 let client = tensor.client.clone();
1378 let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
1379
1380 client
1381 .register(
1382 streams,
1383 OperationIr::NumericFloat(
1384 desc.out.dtype,
1385 NumericOperationIr::MeanDim(desc.clone()),
1386 ),
1387 MeanDimOps::<B>::new(desc),
1388 )
1389 .output()
1390 }
1391
1392 fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
1393 #[derive(new, Debug)]
1394 struct CumsumOps<B: FusionBackend> {
1395 desc: DimOpIr,
1396 _b: PhantomData<B>,
1397 }
1398
1399 impl<B: FusionBackend> Operation<B::FusionRuntime> for CumsumOps<B> {
1400 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1401 let input = handles.get_float_tensor::<B>(&self.desc.input);
1402 let output = B::float_cumsum(input, self.desc.axis);
1403 handles.register_float_tensor::<B>(&self.desc.out.id, output);
1404 }
1405 }
1406
1407 let streams = OperationStreams::with_inputs([&tensor]);
1408
1409 let client = tensor.client.clone();
1410 let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
1411
1412 client
1413 .register(
1414 streams,
1415 OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::CumSum(desc.clone())),
1416 CumsumOps::<B>::new(desc),
1417 )
1418 .output()
1419 }
1420
1421 fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
1422 #[derive(new, Debug)]
1423 struct CumprodOps<B: FusionBackend> {
1424 desc: DimOpIr,
1425 _b: PhantomData<B>,
1426 }
1427
1428 impl<B: FusionBackend> Operation<B::FusionRuntime> for CumprodOps<B> {
1429 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1430 let input = handles.get_float_tensor::<B>(&self.desc.input);
1431 let output = B::float_cumprod(input, self.desc.axis);
1432 handles.register_float_tensor::<B>(&self.desc.out.id, output);
1433 }
1434 }
1435
1436 let streams = OperationStreams::with_inputs([&tensor]);
1437
1438 let client = tensor.client.clone();
1439 let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
1440
1441 client
1442 .register(
1443 streams,
1444 OperationIr::NumericFloat(
1445 desc.out.dtype,
1446 NumericOperationIr::CumProd(desc.clone()),
1447 ),
1448 CumprodOps::<B>::new(desc),
1449 )
1450 .output()
1451 }
1452
1453 fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
1454 #[derive(new, Debug)]
1455 struct CumminOps<B: FusionBackend> {
1456 desc: DimOpIr,
1457 _b: PhantomData<B>,
1458 }
1459
1460 impl<B: FusionBackend> Operation<B::FusionRuntime> for CumminOps<B> {
1461 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1462 let input = handles.get_float_tensor::<B>(&self.desc.input);
1463 let output = B::float_cummin(input, self.desc.axis);
1464 handles.register_float_tensor::<B>(&self.desc.out.id, output);
1465 }
1466 }
1467
1468 let streams = OperationStreams::with_inputs([&tensor]);
1469
1470 let client = tensor.client.clone();
1471 let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
1472
1473 client
1474 .register(
1475 streams,
1476 OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::CumMin(desc.clone())),
1477 CumminOps::<B>::new(desc),
1478 )
1479 .output()
1480 }
1481
1482 fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
1483 #[derive(new, Debug)]
1484 struct CummaxOps<B: FusionBackend> {
1485 desc: DimOpIr,
1486 _b: PhantomData<B>,
1487 }
1488
1489 impl<B: FusionBackend> Operation<B::FusionRuntime> for CummaxOps<B> {
1490 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1491 let input = handles.get_float_tensor::<B>(&self.desc.input);
1492 let output = B::float_cummax(input, self.desc.axis);
1493 handles.register_float_tensor::<B>(&self.desc.out.id, output);
1494 }
1495 }
1496
1497 let streams = OperationStreams::with_inputs([&tensor]);
1498
1499 let client = tensor.client.clone();
1500 let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
1501
1502 client
1503 .register(
1504 streams,
1505 OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::CumMax(desc.clone())),
1506 CummaxOps::<B>::new(desc),
1507 )
1508 .output()
1509 }
1510
1511 fn float_exp(lhs: FloatTensor<Self>) -> FloatTensor<Self> {
1512 unary_float_ops!(ExpOps, B::float_exp);
1513
1514 let streams = OperationStreams::with_inputs([&lhs]);
1515
1516 let client = lhs.client.clone();
1517 let desc = UnaryOpIr::create(lhs.into_ir(), || client.create_empty_handle());
1518
1519 client
1520 .register(
1521 streams,
1522 OperationIr::Float(desc.out.dtype, FloatOperationIr::Exp(desc.clone())),
1523 ExpOps::<B>::new(desc),
1524 )
1525 .output()
1526 }
1527
1528 fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1529 unary_float_ops!(LogOps, B::float_log);
1530
1531 let streams = OperationStreams::with_inputs([&tensor]);
1532
1533 let client = tensor.client.clone();
1534 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1535
1536 client
1537 .register(
1538 streams,
1539 OperationIr::Float(desc.out.dtype, FloatOperationIr::Log(desc.clone())),
1540 LogOps::<B>::new(desc),
1541 )
1542 .output()
1543 }
1544
1545 fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1546 unary_float_ops!(Log1pOps, B::float_log1p);
1547
1548 let streams = OperationStreams::with_inputs([&tensor]);
1549
1550 let client = tensor.client.clone();
1551 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1552
1553 client
1554 .register(
1555 streams,
1556 OperationIr::Float(desc.out.dtype, FloatOperationIr::Log1p(desc.clone())),
1557 Log1pOps::<B>::new(desc),
1558 )
1559 .output()
1560 }
1561
1562 fn float_powf_scalar_impl(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
1563 scalar_float_ops!(PowfOps, B::float_powf_scalar_impl);
1564
1565 let streams = OperationStreams::with_inputs([&lhs]);
1566
1567 let client = lhs.client.clone();
1568 let rhs = rhs.into();
1569 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
1570
1571 client
1572 .register(
1573 streams,
1574 OperationIr::Float(desc.out.dtype, FloatOperationIr::PowfScalar(desc.clone())),
1575 PowfOps::<B>::new(desc),
1576 )
1577 .output()
1578 }
1579
1580 fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1581 unary_float_ops!(SqrtOps, B::float_sqrt);
1582
1583 let streams = OperationStreams::with_inputs([&tensor]);
1584
1585 let client = tensor.client.clone();
1586 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1587
1588 client
1589 .register(
1590 streams,
1591 OperationIr::Float(desc.out.dtype, FloatOperationIr::Sqrt(desc.clone())),
1592 SqrtOps::<B>::new(desc),
1593 )
1594 .output()
1595 }
1596
1597 fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1598 unary_float_ops!(AbsOps, B::float_abs);
1599
1600 let streams = OperationStreams::with_inputs([&tensor]);
1601
1602 let client = tensor.client.clone();
1603 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1604
1605 client
1606 .register(
1607 streams,
1608 OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Abs(desc.clone())),
1609 AbsOps::<B>::new(desc),
1610 )
1611 .output()
1612 }
1613
1614 fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1615 unary_float_ops!(CosOps, B::float_cos);
1616
1617 let streams = OperationStreams::with_inputs([&tensor]);
1618
1619 let client = tensor.client.clone();
1620 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1621
1622 client
1623 .register(
1624 streams,
1625 OperationIr::Float(desc.out.dtype, FloatOperationIr::Cos(desc.clone())),
1626 CosOps::<B>::new(desc),
1627 )
1628 .output()
1629 }
1630
1631 fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1632 unary_float_ops!(SinOps, B::float_sin);
1633
1634 let streams = OperationStreams::with_inputs([&tensor]);
1635
1636 let client = tensor.client.clone();
1637 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1638
1639 client
1640 .register(
1641 streams,
1642 OperationIr::Float(desc.out.dtype, FloatOperationIr::Sin(desc.clone())),
1643 SinOps::<B>::new(desc),
1644 )
1645 .output()
1646 }
1647
1648 fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1649 unary_float_ops!(TanOps, B::float_tan);
1650
1651 let streams = OperationStreams::with_inputs([&tensor]);
1652
1653 let client = tensor.client.clone();
1654 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1655
1656 client
1657 .register(
1658 streams,
1659 OperationIr::Float(desc.out.dtype, FloatOperationIr::Tan(desc.clone())),
1660 TanOps::<B>::new(desc),
1661 )
1662 .output()
1663 }
1664
1665 fn float_cosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1666 unary_float_ops!(CoshOps, B::float_cosh);
1667
1668 let streams = OperationStreams::with_inputs([&tensor]);
1669
1670 let client = tensor.client.clone();
1671 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1672
1673 client
1674 .register(
1675 streams,
1676 OperationIr::Float(desc.out.dtype, FloatOperationIr::Cosh(desc.clone())),
1677 CoshOps::<B>::new(desc),
1678 )
1679 .output()
1680 }
1681
1682 fn float_sinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1683 unary_float_ops!(SinhOps, B::float_sinh);
1684
1685 let streams = OperationStreams::with_inputs([&tensor]);
1686
1687 let client = tensor.client.clone();
1688 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1689
1690 client
1691 .register(
1692 streams,
1693 OperationIr::Float(desc.out.dtype, FloatOperationIr::Sinh(desc.clone())),
1694 SinhOps::<B>::new(desc),
1695 )
1696 .output()
1697 }
1698
1699 fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1700 unary_float_ops!(TanhOps, B::float_tanh);
1701
1702 let streams = OperationStreams::with_inputs([&tensor]);
1703
1704 let client = tensor.client.clone();
1705 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1706
1707 client
1708 .register(
1709 streams,
1710 OperationIr::Float(desc.out.dtype, FloatOperationIr::Tanh(desc.clone())),
1711 TanhOps::<B>::new(desc),
1712 )
1713 .output()
1714 }
1715
1716 fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1717 unary_float_ops!(ArcCosOps, B::float_acos);
1718
1719 let streams = OperationStreams::with_inputs([&tensor]);
1720
1721 let client = tensor.client.clone();
1722 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1723
1724 client
1725 .register(
1726 streams,
1727 OperationIr::Float(desc.out.dtype, FloatOperationIr::ArcCos(desc.clone())),
1728 ArcCosOps::<B>::new(desc),
1729 )
1730 .output()
1731 }
1732
1733 fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1734 unary_float_ops!(ArcCoshOps, B::float_acosh);
1735
1736 let streams = OperationStreams::with_inputs([&tensor]);
1737
1738 let client = tensor.client.clone();
1739 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1740
1741 client
1742 .register(
1743 streams,
1744 OperationIr::Float(desc.out.dtype, FloatOperationIr::ArcCosh(desc.clone())),
1745 ArcCoshOps::<B>::new(desc),
1746 )
1747 .output()
1748 }
1749
1750 fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1751 unary_float_ops!(ArcSinOps, B::float_asin);
1752
1753 let streams = OperationStreams::with_inputs([&tensor]);
1754
1755 let client = tensor.client.clone();
1756 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1757
1758 client
1759 .register(
1760 streams,
1761 OperationIr::Float(desc.out.dtype, FloatOperationIr::ArcSin(desc.clone())),
1762 ArcSinOps::<B>::new(desc),
1763 )
1764 .output()
1765 }
1766
1767 fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1768 unary_float_ops!(ArcSinhOps, B::float_asinh);
1769
1770 let streams = OperationStreams::with_inputs([&tensor]);
1771
1772 let client = tensor.client.clone();
1773 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1774
1775 client
1776 .register(
1777 streams,
1778 OperationIr::Float(desc.out.dtype, FloatOperationIr::ArcSinh(desc.clone())),
1779 ArcSinhOps::<B>::new(desc),
1780 )
1781 .output()
1782 }
1783
1784 fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1785 unary_float_ops!(ArcTanOps, B::float_atan);
1786
1787 let streams = OperationStreams::with_inputs([&tensor]);
1788
1789 let client = tensor.client.clone();
1790 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1791
1792 client
1793 .register(
1794 streams,
1795 OperationIr::Float(desc.out.dtype, FloatOperationIr::ArcTan(desc.clone())),
1796 ArcTanOps::<B>::new(desc),
1797 )
1798 .output()
1799 }
1800
1801 fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1802 unary_float_ops!(ArcTanhOps, B::float_atanh);
1803
1804 let streams = OperationStreams::with_inputs([&tensor]);
1805
1806 let client = tensor.client.clone();
1807 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1808
1809 client
1810 .register(
1811 streams,
1812 OperationIr::Float(desc.out.dtype, FloatOperationIr::ArcTanh(desc.clone())),
1813 ArcTanhOps::<B>::new(desc),
1814 )
1815 .output()
1816 }
1817
1818 fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
1819 binary_float_ops!(ArcTan2Ops, B::float_atan2);
1820
1821 let streams = OperationStreams::with_inputs([&lhs, &rhs]);
1822
1823 let client = lhs.client.clone();
1824 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
1825 client.create_empty_handle()
1826 });
1827
1828 client
1829 .register(
1830 streams,
1831 OperationIr::Float(desc.out.dtype, FloatOperationIr::ArcTan2(desc.clone())),
1832 ArcTan2Ops::<B>::new(desc),
1833 )
1834 .output()
1835 }
1836
1837 fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1838 unary_float_ops!(Recip, B::float_recip);
1839
1840 let streams = OperationStreams::with_inputs([&tensor]);
1841
1842 let client = tensor.client.clone();
1843 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1844
1845 client
1846 .register(
1847 streams,
1848 OperationIr::Float(desc.out.dtype, FloatOperationIr::Recip(desc.clone())),
1849 Recip::<B>::new(desc),
1850 )
1851 .output()
1852 }
1853
1854 fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1855 unary_float_ops!(TanhOps, B::float_erf);
1856
1857 let streams = OperationStreams::with_inputs([&tensor]);
1858
1859 let client = tensor.client.clone();
1860 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1861
1862 client
1863 .register(
1864 streams,
1865 OperationIr::Float(desc.out.dtype, FloatOperationIr::Erf(desc.clone())),
1866 TanhOps::<B>::new(desc),
1867 )
1868 .output()
1869 }
1870
1871 fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {
1872 #[derive(new, Debug)]
1873 struct CatOps<B: FusionBackend> {
1874 desc: CatOpIr,
1875 _b: PhantomData<B>,
1876 }
1877
1878 impl<B: FusionBackend> Operation<B::FusionRuntime> for CatOps<B> {
1879 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1880 let tensors = self
1881 .desc
1882 .tensors
1883 .iter()
1884 .map(|tensor| handles.get_float_tensor::<B>(tensor))
1885 .collect();
1886
1887 let output = B::float_cat(tensors, self.desc.dim);
1888
1889 handles.register_float_tensor::<B>(&self.desc.out.id, output);
1890 }
1891 }
1892
1893 let streams = OperationStreams::with_inputs(&tensors);
1894
1895 let client = tensors.first().unwrap().client.clone();
1896 let tensors = tensors.into_iter().map(|t| t.into_ir()).collect();
1897 let desc = CatOpIr::create(tensors, dim, || client.create_empty_handle());
1898
1899 client
1900 .register(
1901 streams,
1902 OperationIr::BaseFloat(BaseOperationIr::Cat(desc.clone())),
1903 CatOps::<B>::new(desc),
1904 )
1905 .output()
1906 }
1907
1908 fn float_argmax(tensor: FloatTensor<Self>, dim: usize, out_dtype: IntDType) -> IntTensor<Self> {
1909 reduce_float2int_ops!(ArgMaxOps, |input, axis, _, dtype| B::float_argmax(
1910 input, axis, dtype
1911 ));
1912
1913 let streams = OperationStreams::with_inputs([&tensor]);
1914
1915 let client = tensor.client.clone();
1916 let desc = ReduceDimOpIr::create_arg(tensor.into_ir(), dim, 1, out_dtype.into(), || {
1917 client.create_empty_handle()
1918 });
1919
1920 client
1921 .register(
1922 streams,
1923 OperationIr::NumericFloat(
1924 desc.input.dtype,
1925 NumericOperationIr::ArgMax(desc.clone()),
1926 ),
1927 ArgMaxOps::<B>::new(desc),
1928 )
1929 .output()
1930 }
1931
1932 fn float_argtopk(
1933 tensor: FloatTensor<Self>,
1934 dim: usize,
1935 k: usize,
1936 out_dtype: IntDType,
1937 ) -> IntTensor<Self> {
1938 reduce_float2int_ops!(ArgTopKOps, B::float_argtopk);
1939
1940 let streams = OperationStreams::with_inputs([&tensor]);
1941
1942 let client = tensor.client.clone();
1943 let desc = ReduceDimOpIr::create_arg(tensor.into_ir(), dim, k, out_dtype.into(), || {
1944 client.create_empty_handle()
1945 });
1946
1947 client
1948 .register(
1949 streams,
1950 OperationIr::NumericFloat(
1951 desc.input.dtype,
1952 NumericOperationIr::ArgTopK(desc.clone()),
1953 ),
1954 ArgTopKOps::<B>::new(desc),
1955 )
1956 .output()
1957 }
1958
1959 fn float_topk(tensor: FloatTensor<Self>, dim: usize, k: usize) -> FloatTensor<Self> {
1960 reduce_float_ops!(TopKOps, B::float_topk);
1961
1962 let streams = OperationStreams::with_inputs([&tensor]);
1963
1964 let client = tensor.client.clone();
1965 let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, k, || client.create_empty_handle());
1966
1967 client
1968 .register(
1969 streams,
1970 OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::TopK(desc.clone())),
1971 TopKOps::<B>::new(desc),
1972 )
1973 .output()
1974 }
1975
1976 fn float_repeat_dim(tensor: FloatTensor<Self>, dim: usize, times: usize) -> FloatTensor<Self> {
1977 #[derive(new, Debug)]
1978 struct RepeatDimOps<B: FusionBackend> {
1979 desc: RepeatDimOpIr,
1980 _b: PhantomData<B>,
1981 }
1982
1983 impl<B: FusionBackend> Operation<B::FusionRuntime> for RepeatDimOps<B> {
1984 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1985 let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
1986
1987 let output = B::float_repeat_dim(tensor, self.desc.dim, self.desc.times);
1988
1989 handles.register_float_tensor::<B>(&self.desc.out.id, output);
1990 }
1991 }
1992
1993 let streams = OperationStreams::with_inputs([&tensor]);
1994
1995 let client = tensor.client.clone();
1996 let desc = RepeatDimOpIr::create(tensor.into_ir(), dim, times, || {
1997 client.create_empty_handle()
1998 });
1999
2000 client
2001 .register(
2002 streams,
2003 OperationIr::BaseFloat(BaseOperationIr::RepeatDim(desc.clone())),
2004 RepeatDimOps::<B>::new(desc),
2005 )
2006 .output()
2007 }
2008
2009 fn float_argmin(tensor: FloatTensor<Self>, dim: usize, out_dtype: IntDType) -> IntTensor<Self> {
2010 reduce_float2int_ops!(ArgMinOps, |input, axis, _, dtype| B::float_argmin(
2011 input, axis, dtype
2012 ));
2013
2014 let streams = OperationStreams::with_inputs([&tensor]);
2015
2016 let client = tensor.client.clone();
2017 let desc = ReduceDimOpIr::create_arg(tensor.into_ir(), dim, 1, out_dtype.into(), || {
2018 client.create_empty_handle()
2019 });
2020
2021 client
2022 .register(
2023 streams,
2024 OperationIr::NumericFloat(
2025 desc.input.dtype,
2026 NumericOperationIr::ArgMin(desc.clone()),
2027 ),
2028 ArgMinOps::<B>::new(desc),
2029 )
2030 .output()
2031 }
2032
2033 fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
2034 unary_float_ops!(MaxOps, B::float_max, reduce);
2035
2036 let streams = OperationStreams::with_inputs([&tensor]);
2037
2038 let client = tensor.client.clone();
2039 let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
2040
2041 client
2042 .register(
2043 streams,
2044 OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Max(desc.clone())),
2045 MaxOps::<B>::new(desc.into()),
2046 )
2047 .output()
2048 }
2049
2050 fn float_max_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
2051 reduce_float_ops!(MaxDimOps, |tensor, axis, _| B::float_max_dim(tensor, axis));
2052
2053 let streams = OperationStreams::with_inputs([&tensor]);
2054
2055 let client = tensor.client.clone();
2056 let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
2057
2058 client
2059 .register(
2060 streams,
2061 OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::MaxDim(desc.clone())),
2062 MaxDimOps::<B>::new(desc),
2063 )
2064 .output()
2065 }
2066
2067 fn float_max_dim_with_indices(
2068 tensor: FloatTensor<Self>,
2069 dim: usize,
2070 indices_dtype: IntDType,
2071 ) -> (FloatTensor<Self>, IntTensor<Self>) {
2072 #[derive(new, Debug)]
2073 struct MaxDimWithIndicesOps<B: FusionBackend> {
2074 desc: ReduceDimWithIndicesOpIr,
2075 _b: PhantomData<B>,
2076 }
2077
2078 impl<B: FusionBackend> Operation<B::FusionRuntime> for MaxDimWithIndicesOps<B> {
2079 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2080 let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
2081 let (output, indices) = B::float_max_dim_with_indices(
2082 tensor,
2083 self.desc.dim,
2084 self.desc.out_indices.dtype.into(),
2085 );
2086
2087 handles.register_float_tensor::<B>(&self.desc.out.id, output);
2088 handles.register_int_tensor::<B>(&self.desc.out_indices.id, indices);
2089 }
2090 }
2091
2092 let streams = OperationStreams::with_inputs([&tensor]);
2093
2094 let client = tensor.client.clone();
2095 let desc =
2096 ReduceDimWithIndicesOpIr::create(tensor.into_ir(), dim, indices_dtype.into(), || {
2097 client.create_empty_handle()
2098 });
2099
2100 client
2101 .register(
2102 streams,
2103 OperationIr::NumericFloat(
2104 desc.tensor.dtype,
2105 NumericOperationIr::MaxDimWithIndices(desc.clone()),
2106 ),
2107 MaxDimWithIndicesOps::<B>::new(desc),
2108 )
2109 .outputs()
2110 .into()
2111 }
2112
2113 fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
2114 unary_float_ops!(MinOps, B::float_min, reduce);
2115
2116 let streams = OperationStreams::with_inputs([&tensor]);
2117
2118 let client = tensor.client.clone();
2119 let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
2120
2121 client
2122 .register(
2123 streams,
2124 OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Min(desc.clone())),
2125 MinOps::<B>::new(desc.into()),
2126 )
2127 .output()
2128 }
2129
2130 fn float_min_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
2131 reduce_float_ops!(MinDimOps, |tensor, axis, _| B::float_min_dim(tensor, axis));
2132
2133 let streams = OperationStreams::with_inputs([&tensor]);
2134
2135 let client = tensor.client.clone();
2136 let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
2137
2138 client
2139 .register(
2140 streams,
2141 OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::MinDim(desc.clone())),
2142 MinDimOps::<B>::new(desc),
2143 )
2144 .output()
2145 }
2146
2147 fn float_min_dim_with_indices(
2148 tensor: FloatTensor<Self>,
2149 dim: usize,
2150 indices_dtype: IntDType,
2151 ) -> (FloatTensor<Self>, IntTensor<Self>) {
2152 #[derive(new, Debug)]
2153 struct MinDimWithIndicesOps<B: FusionBackend> {
2154 desc: ReduceDimWithIndicesOpIr,
2155 _b: PhantomData<B>,
2156 }
2157
2158 impl<B: FusionBackend> Operation<B::FusionRuntime> for MinDimWithIndicesOps<B> {
2159 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2160 let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
2161 let (output, indices) = B::float_min_dim_with_indices(
2162 tensor,
2163 self.desc.dim,
2164 self.desc.out_indices.dtype.into(),
2165 );
2166
2167 handles.register_float_tensor::<B>(&self.desc.out.id, output);
2168 handles.register_int_tensor::<B>(&self.desc.out_indices.id, indices);
2169 }
2170 }
2171 let streams = OperationStreams::with_inputs([&tensor]);
2172
2173 let client = tensor.client.clone();
2174 let desc =
2175 ReduceDimWithIndicesOpIr::create(tensor.into_ir(), dim, indices_dtype.into(), || {
2176 client.create_empty_handle()
2177 });
2178
2179 client
2180 .register(
2181 streams,
2182 OperationIr::NumericFloat(
2183 desc.tensor.dtype,
2184 NumericOperationIr::MinDimWithIndices(desc.clone()),
2185 ),
2186 MinDimWithIndicesOps::<B>::new(desc),
2187 )
2188 .outputs()
2189 .into()
2190 }
2191
2192 fn float_max_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
2193 unary_float_ops!(MaxAbsOps, B::float_max_abs, reduce);
2194
2195 let streams = OperationStreams::with_inputs([&tensor]);
2196
2197 let client = tensor.client.clone();
2198 let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
2199
2200 client
2201 .register(
2202 streams,
2203 OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::MaxAbs(desc.clone())),
2204 MaxAbsOps::<B>::new(desc.into()),
2205 )
2206 .output()
2207 }
2208
2209 fn float_max_abs_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
2210 reduce_float_ops!(MaxAbsDimOps, |tensor, axis, _| B::float_max_abs_dim(
2211 tensor, axis
2212 ));
2213
2214 let streams = OperationStreams::with_inputs([&tensor]);
2215
2216 let client = tensor.client.clone();
2217 let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
2218
2219 client
2220 .register(
2221 streams,
2222 OperationIr::NumericFloat(
2223 desc.out.dtype,
2224 NumericOperationIr::MaxAbsDim(desc.clone()),
2225 ),
2226 MaxAbsDimOps::<B>::new(desc),
2227 )
2228 .output()
2229 }
2230
2231 fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
2233 binary_float_ops!(PowOps, B::float_powf);
2234
2235 let streams = OperationStreams::with_inputs([&lhs, &rhs]);
2236
2237 let client = lhs.client.clone();
2238 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
2239 client.create_empty_handle()
2240 });
2241
2242 client
2243 .register(
2244 streams,
2245 OperationIr::Float(desc.out.dtype, FloatOperationIr::Powf(desc.clone())),
2246 PowOps::<B>::new(desc),
2247 )
2248 .output()
2249 }
2250
2251 fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
2252 #[derive(new, Debug)]
2253 struct PermuteDimsOps<B: FusionBackend> {
2254 desc: PermuteOpIr,
2255 _b: PhantomData<B>,
2256 }
2257
2258 impl<B: FusionBackend> Operation<B::FusionRuntime> for PermuteDimsOps<B> {
2259 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2260 let input = handles.get_float_tensor::<B>(&self.desc.input);
2261 let output = B::float_permute(input, self.desc.axes.as_slice());
2262 handles.register_float_tensor::<B>(&self.desc.out.id, output);
2263 }
2264 }
2265
2266 let streams = OperationStreams::with_inputs([&tensor]);
2267
2268 let client = tensor.client.clone();
2269 let desc = PermuteOpIr::create(tensor.into_ir(), axes.into(), || {
2270 client.create_empty_handle()
2271 });
2272
2273 client
2274 .register(
2275 streams,
2276 OperationIr::BaseInt(BaseOperationIr::Permute(desc.clone())),
2277 PermuteDimsOps::<B>::new(desc),
2278 )
2279 .output()
2280 }
2281
2282 fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
2283 #[derive(new, Debug)]
2284 struct ExpandOps<B: FusionBackend> {
2285 desc: ShapeOpIr,
2286 _b: PhantomData<B>,
2287 }
2288
2289 impl<B: FusionBackend> Operation<B::FusionRuntime> for ExpandOps<B> {
2290 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2291 let input = handles.get_float_tensor::<B>(&self.desc.input);
2292 let output = B::float_expand(input, self.desc.out.shape.clone());
2293
2294 handles.register_float_tensor::<B>(&self.desc.out.id, output);
2295 }
2296 }
2297
2298 let streams = OperationStreams::with_inputs([&tensor]);
2299
2300 let client = tensor.client.clone();
2301 let desc = ShapeOpIr::expand(tensor.into_ir(), shape, || client.create_empty_handle());
2302
2303 client
2304 .register(
2305 streams,
2306 OperationIr::BaseFloat(BaseOperationIr::Expand(desc.clone())),
2307 ExpandOps::<B>::new(desc),
2308 )
2309 .output()
2310 }
2311
2312 fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
2313 #[derive(new, Debug)]
2314 struct FlipOps<B: FusionBackend> {
2315 desc: FlipOpIr,
2316 _b: PhantomData<B>,
2317 }
2318
2319 impl<B: FusionBackend> Operation<B::FusionRuntime> for FlipOps<B> {
2320 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2321 let input = handles.get_float_tensor::<B>(&self.desc.input);
2322 let output = B::float_flip(input, &self.desc.axes);
2323 handles.register_float_tensor::<B>(&self.desc.out.id, output);
2324 }
2325 }
2326
2327 let streams = OperationStreams::with_inputs([&tensor]);
2328
2329 let client = tensor.client.clone();
2330 let desc = FlipOpIr::create(tensor.into_ir(), axes.into(), || {
2331 client.create_empty_handle()
2332 });
2333
2334 client
2335 .register(
2336 streams,
2337 OperationIr::BaseInt(BaseOperationIr::Flip(desc.clone())),
2338 FlipOps::<B>::new(desc),
2339 )
2340 .output()
2341 }
2342
2343 fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
2344 unary_float_ops!(RoundOps, B::float_round);
2345
2346 let streams = OperationStreams::with_inputs([&tensor]);
2347
2348 let client = tensor.client.clone();
2349 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
2350
2351 client
2352 .register(
2353 streams,
2354 OperationIr::Float(desc.out.dtype, FloatOperationIr::Round(desc.clone())),
2355 RoundOps::<B>::new(desc),
2356 )
2357 .output()
2358 }
2359
2360 fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
2361 unary_float_ops!(FloorOps, B::float_floor);
2362
2363 let streams = OperationStreams::with_inputs([&tensor]);
2364
2365 let client = tensor.client.clone();
2366 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
2367
2368 client
2369 .register(
2370 streams,
2371 OperationIr::Float(desc.out.dtype, FloatOperationIr::Floor(desc.clone())),
2372 FloorOps::<B>::new(desc),
2373 )
2374 .output()
2375 }
2376
2377 fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
2378 unary_float_ops!(CeilOps, B::float_ceil);
2379
2380 let streams = OperationStreams::with_inputs([&tensor]);
2381
2382 let client = tensor.client.clone();
2383 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
2384
2385 client
2386 .register(
2387 streams,
2388 OperationIr::Float(desc.out.dtype, FloatOperationIr::Ceil(desc.clone())),
2389 CeilOps::<B>::new(desc),
2390 )
2391 .output()
2392 }
2393
2394 fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
2395 unary_float_ops!(TruncOps, B::float_trunc);
2396
2397 let streams = OperationStreams::with_inputs([&tensor]);
2398
2399 let client = tensor.client.clone();
2400 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
2401
2402 client
2403 .register(
2404 streams,
2405 OperationIr::Float(desc.out.dtype, FloatOperationIr::Trunc(desc.clone())),
2406 TruncOps::<B>::new(desc),
2407 )
2408 .output()
2409 }
2410
2411 fn float_cast(tensor: FloatTensor<Self>, dtype: burn_backend::FloatDType) -> FloatTensor<Self> {
2412 #[derive(new, Debug)]
2413 struct CastOps<B: FusionBackend> {
2414 desc: CastOpIr,
2415 dtype: burn_backend::FloatDType,
2416 _b: PhantomData<B>,
2417 }
2418
2419 impl<B: FusionBackend> Operation<B::FusionRuntime> for CastOps<B> {
2420 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2421 let input = handles.get_float_tensor::<B>(&self.desc.input);
2422 let output: B::FloatTensorPrimitive = B::float_cast(input, self.dtype);
2423 handles.register_float_tensor::<B>(&self.desc.out.id, output);
2424 }
2425 }
2426
2427 let streams = OperationStreams::with_inputs([&tensor]);
2428
2429 let client = tensor.client.clone();
2430 let desc = CastOpIr::create(tensor.into_ir(), dtype.into(), || {
2431 client.create_empty_handle()
2432 });
2433
2434 client
2435 .register(
2436 streams,
2437 OperationIr::BaseFloat(BaseOperationIr::Cast(desc.clone())),
2438 CastOps::<B>::new(desc, dtype),
2439 )
2440 .output()
2441 }
2442
2443 fn float_unfold(
2444 tensor: FloatTensor<Self>,
2445 dim: usize,
2446 size: usize,
2447 step: usize,
2448 ) -> FloatTensor<Self> {
2449 #[derive(new, Debug)]
2450 struct UnfoldOps<B: FusionBackend> {
2451 desc: UnfoldOpIr,
2452 _b: PhantomData<B>,
2453 }
2454
2455 impl<B: FusionBackend> Operation<B::FusionRuntime> for UnfoldOps<B> {
2456 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2457 let input = handles.get_float_tensor::<B>(&self.desc.input);
2458 let output = B::float_unfold(input, self.desc.dim, self.desc.size, self.desc.step);
2459
2460 handles.register_float_tensor::<B>(&self.desc.out.id, output);
2461 }
2462 }
2463
2464 let streams = OperationStreams::with_inputs([&tensor]);
2465
2466 let client = tensor.client.clone();
2467 let desc = UnfoldOpIr::create(tensor.into_ir(), dim, size, step, || {
2468 client.create_empty_handle()
2469 });
2470
2471 client
2472 .register(
2473 streams,
2474 OperationIr::BaseFloat(BaseOperationIr::Unfold(desc.clone())),
2475 UnfoldOps::<B>::new(desc),
2476 )
2477 .output()
2478 }
2479
2480 fn float_is_nan(tensor: FloatTensor<Self>, out_dtype: BoolDType) -> BoolTensor<Self> {
2481 #[derive(new, Debug)]
2482 struct IsNanOps<B: FusionBackend> {
2483 desc: UnaryOpIr,
2484 _b: PhantomData<B>,
2485 }
2486 impl<B: FusionBackend> Operation<B::FusionRuntime> for IsNanOps<B> {
2487 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2488 let input = handles.get_float_tensor::<B>(&self.desc.input);
2489 let output = B::float_is_nan(input, self.desc.out.dtype.into());
2490 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
2491 }
2492 }
2493
2494 let streams = OperationStreams::with_inputs([&tensor]);
2495
2496 let client = tensor.client.clone();
2497 let desc = UnaryOpIr::create_comparison(tensor.into_ir(), out_dtype.into(), || {
2498 client.create_empty_handle()
2499 });
2500
2501 client
2502 .register(
2503 streams,
2504 OperationIr::Float(desc.input.dtype, FloatOperationIr::IsNan(desc.clone())),
2505 IsNanOps::<B>::new(desc),
2506 )
2507 .output()
2508 }
2509
2510 fn float_is_inf(tensor: FloatTensor<Self>, out_dtype: BoolDType) -> BoolTensor<Self> {
2511 #[derive(new, Debug)]
2512 struct IsInfOps<B: FusionBackend> {
2513 desc: UnaryOpIr,
2514 _b: PhantomData<B>,
2515 }
2516 impl<B: FusionBackend> Operation<B::FusionRuntime> for IsInfOps<B> {
2517 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2518 let input = handles.get_float_tensor::<B>(&self.desc.input);
2519 let output = B::float_is_inf(input, self.desc.out.dtype.into());
2520 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
2521 }
2522 }
2523
2524 let streams = OperationStreams::with_inputs([&tensor]);
2525
2526 let client = tensor.client.clone();
2527 let desc = UnaryOpIr::create_comparison(tensor.into_ir(), out_dtype.into(), || {
2528 client.create_empty_handle()
2529 });
2530
2531 client
2532 .register(
2533 streams,
2534 OperationIr::Float(desc.input.dtype, FloatOperationIr::IsInf(desc.clone())),
2535 IsInfOps::<B>::new(desc),
2536 )
2537 .output()
2538 }
2539
2540 fn float_grid_sample_2d(
2541 tensor: FloatTensor<Self>,
2542 grid: FloatTensor<Self>,
2543 options: GridSampleOptions,
2544 ) -> FloatTensor<Self> {
2545 #[derive(new, Debug)]
2546 struct GridSample2dOps<B: FusionBackend> {
2547 desc: GridSample2dOpIr,
2548 _b: PhantomData<B>,
2549 }
2550
2551 impl<B: FusionBackend> Operation<B::FusionRuntime> for GridSample2dOps<B> {
2552 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2553 let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
2554 let grid = handles.get_float_tensor::<B>(&self.desc.grid);
2555 let output =
2556 B::float_grid_sample_2d(tensor, grid, self.desc.options.clone().into());
2557 handles.register_float_tensor::<B>(&self.desc.out.id, output);
2558 }
2559 }
2560
2561 let streams = OperationStreams::with_inputs([&tensor, &grid]);
2562
2563 let client = tensor.client.clone();
2564 let desc =
2565 GridSample2dOpIr::create(tensor.into_ir(), grid.into_ir(), options.into(), || {
2566 client.create_empty_handle()
2567 });
2568
2569 client
2570 .register(
2571 streams,
2572 OperationIr::Float(desc.out.dtype, FloatOperationIr::GridSample2d(desc.clone())),
2573 GridSample2dOps::<B>::new(desc),
2574 )
2575 .output()
2576 }
2577}