1use super::NoOp;
2use crate::{
3 Fusion, FusionBackend, binary_float_cmp_ops, binary_float_ops, get_client,
4 ops::binary::check_binary_op_types,
5 reduce_float_ops, reduce_float2int_ops, scalar_float_cmp_ops, scalar_float_ops,
6 stream::{OperationStreams, StreamId, execution::Operation},
7 unary_float_ops,
8};
9use burn_ir::*;
10use burn_tensor::ops::unfold::calculate_unfold_shape;
11use burn_tensor::{
12 Device, Distribution, Element, FloatDType, Shape, Slice, TensorData, TensorMetadata,
13 ops::{BoolTensor, FloatElem, FloatTensor, FloatTensorOps, IntTensor},
14};
15use std::marker::PhantomData;
16
17impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
18 fn float_from_data(data: TensorData, device: &Device<Self>) -> FloatTensor<Self> {
19 let stream = StreamId::current();
20 let client = get_client::<B>(&device.clone());
21 let dtype = data.dtype;
22 let tensor = B::float_from_data(data, device);
23 let shape = tensor.shape();
24
25 let handle = B::float_tensor_handle(tensor);
26 let out = client.register_tensor(handle, shape, stream, dtype);
27 let desc = out.to_ir_out();
28
29 client.register(
30 OperationStreams::default(),
31 OperationIr::Init(InitOperationIr { out: desc }),
32 NoOp::<B>::new(),
33 );
34
35 out
36 }
37
38 fn float_random(
39 shape: Shape,
40 distribution: Distribution,
41 device: &Device<Self>,
42 ) -> FloatTensor<Self> {
43 #[derive(new, Debug)]
44 struct RandomOps<B: FusionBackend> {
45 desc: RandomOpIr,
46 device: Device<B>,
47 }
48
49 impl<B: FusionBackend> Operation<B::FusionRuntime> for RandomOps<B> {
50 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
51 let output: B::FloatTensorPrimitive = B::float_random(
52 self.desc.out.shape.clone(),
53 self.desc.distribution,
54 &self.device,
55 );
56 handles.register_float_tensor::<B>(&self.desc.out.id, output);
57 }
58 }
59
60 let client = get_client::<B>(&device.clone());
61 let out = client.tensor_uninitialized(shape, B::FloatElem::dtype());
62
63 let desc = RandomOpIr {
64 out: out.to_ir_out(),
65 distribution,
66 };
67 client.register(
68 OperationStreams::default(),
69 OperationIr::Float(
70 FloatElem::<Self>::dtype(),
71 FloatOperationIr::Random(desc.clone()),
72 ),
73 RandomOps::<B>::new(desc, device.clone()),
74 );
75
76 out
77 }
78
79 fn float_zeros(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
80 #[derive(new, Debug)]
81 struct ZerosOps<B: FusionBackend> {
82 out: TensorIr,
83 device: Device<B>,
84 }
85
86 impl<B: FusionBackend> Operation<B::FusionRuntime> for ZerosOps<B> {
87 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
88 let shape = self.out.shape.clone();
89 let output = B::float_zeros(shape, &self.device, self.out.dtype.into());
90 handles.register_float_tensor::<B>(&self.out.id, output);
91 }
92 }
93
94 let dtype = dtype.into();
95 let client = get_client::<B>(&device.clone());
96 let out = client.tensor_uninitialized(shape, dtype);
97
98 let desc = out.to_ir_out();
99 client.register(
100 OperationStreams::default(),
101 OperationIr::NumericFloat(dtype, NumericOperationIr::Zeros(desc.clone())),
102 ZerosOps::<B>::new(desc, device.clone()),
103 );
104
105 out
106 }
107
108 fn float_ones(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
109 #[derive(new, Debug)]
110 struct OnesOps<B: FusionBackend> {
111 out: TensorIr,
112 device: Device<B>,
113 }
114
115 impl<B: FusionBackend> Operation<B::FusionRuntime> for OnesOps<B> {
116 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
117 let shape = self.out.shape.clone();
118 let output = B::float_ones(shape, &self.device, self.out.dtype.into());
119 handles.register_float_tensor::<B>(&self.out.id, output);
120 }
121 }
122
123 let dtype = dtype.into();
124 let client = get_client::<B>(&device.clone());
125 let out = client.tensor_uninitialized(shape, dtype);
126
127 let desc = out.to_ir_out();
128 client.register(
129 OperationStreams::default(),
130 OperationIr::NumericFloat(dtype, NumericOperationIr::Ones(desc.clone())),
131 OnesOps::<B>::new(desc, device.clone()),
132 );
133
134 out
135 }
136
137 fn float_full(
138 shape: Shape,
139 fill_value: FloatElem<Self>,
140 device: &Device<Self>,
141 dtype: FloatDType,
142 ) -> FloatTensor<Self> {
143 #[derive(new, Debug)]
144 struct FullOps<B: FusionBackend> {
145 out: TensorIr,
146 elem: ScalarIr,
147 device: Device<B>,
148 }
149
150 impl<B: FusionBackend> Operation<B::FusionRuntime> for FullOps<B> {
151 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
152 let shape = self.out.shape.clone();
153 let output: B::FloatTensorPrimitive =
154 B::float_full(shape, self.elem.elem(), &self.device, self.out.dtype.into());
155 handles.register_float_tensor::<B>(&self.out.id, output);
156 }
157 }
158
159 let dtype = dtype.into();
160 let client = get_client::<B>(&device.clone());
161 let out = client.tensor_uninitialized(shape, dtype);
162
163 let desc = (out.to_ir_out(), ScalarIr::with_dtype(fill_value, &dtype));
164 client.register(
165 OperationStreams::default(),
166 OperationIr::NumericFloat(dtype, NumericOperationIr::Full(desc.clone())),
167 FullOps::<B>::new(desc.0, desc.1, device.clone()),
168 );
169
170 out
171 }
172
173 async fn float_into_data(tensor: FloatTensor<Self>) -> TensorData {
174 tensor.into_data::<B>().await
175 }
176
177 fn float_device(tensor: &FloatTensor<Self>) -> Device<Self> {
178 tensor.client.device().clone()
179 }
180
181 fn float_to_device(tensor: FloatTensor<Self>, device: &Device<Self>) -> FloatTensor<Self> {
182 let device_original: &B::Device = tensor.client.device();
183 let device_target: B::Device = device.clone();
184
185 if device_original == &device_target {
186 return tensor;
187 }
188
189 let id = tensor.stream;
190 let client_target = get_client::<B>(&device_target);
191 let client_original = tensor.client.clone();
192
193 client_original
194 .clone()
195 .change_client_float::<B>(tensor.into_ir(), client_target, id)
196 }
197
198 fn float_into_int(tensor: FloatTensor<Self>) -> IntTensor<Self> {
199 #[derive(new, Debug)]
200 struct IntoIntOps<B: FusionBackend> {
201 desc: UnaryOpIr,
202 _b: PhantomData<B>,
203 }
204
205 impl<B: FusionBackend> Operation<B::FusionRuntime> for IntoIntOps<B> {
206 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
207 let input = handles.get_float_tensor::<B>(&self.desc.input);
208 let output = B::float_into_int(input);
209
210 handles.register_int_tensor::<B>(&self.desc.out.id, output);
211 }
212 }
213
214 let mut streams = OperationStreams::default();
215 streams.tensor(&tensor);
216 let dtype = tensor.dtype;
217 let out = tensor
218 .client
219 .tensor_uninitialized(tensor.shape.clone(), B::IntElem::dtype());
220
221 let desc = UnaryOpIr {
222 input: tensor.into_ir(),
223 out: out.to_ir_out(),
224 };
225 out.client.register(
226 streams,
227 OperationIr::Float(dtype, FloatOperationIr::IntoInt(desc.clone())),
228 IntoIntOps::<B>::new(desc),
229 );
230
231 out
232 }
233
234 fn float_empty(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
235 #[derive(new, Debug)]
236 struct EmptyOps<B: FusionBackend> {
237 desc: TensorIr,
238 device: Device<B>,
239 }
240
241 impl<B: FusionBackend> Operation<B::FusionRuntime> for EmptyOps<B> {
242 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
243 let output = B::float_empty(
244 self.desc.shape.clone(),
245 &self.device,
246 self.desc.dtype.into(),
247 );
248 handles.register_float_tensor::<B>(&self.desc.id, output);
249 }
250 }
251
252 let client = get_client::<B>(&device.clone());
253 let out = client.tensor_uninitialized(shape.clone(), dtype.into());
254
255 let desc = out.to_ir_out();
256
257 client.register(
258 OperationStreams::default(),
259 OperationIr::BaseFloat(BaseOperationIr::Empty(desc.clone())),
260 EmptyOps::<B>::new(desc, device.clone()),
261 );
262
263 out
264 }
265
266 fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
267 binary_float_ops!(AddOps, B::float_add);
268
269 let mut streams = OperationStreams::default();
270 streams.tensor(&lhs);
271 streams.tensor(&rhs);
272 let dtype = lhs.dtype;
273 let out = lhs
274 .client
275 .tensor_uninitialized(lhs.shape.broadcast(&rhs.shape).unwrap(), lhs.dtype);
276
277 let desc = BinaryOpIr {
278 lhs: lhs.into_ir(),
279 rhs: rhs.into_ir(),
280 out: out.to_ir_out(),
281 };
282
283 out.client.register(
284 streams,
285 OperationIr::NumericFloat(dtype, NumericOperationIr::Add(desc.clone())),
286 AddOps::<B>::new(desc),
287 );
288
289 out
290 }
291
292 fn float_add_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
293 scalar_float_ops!(AddOps, B::float_add_scalar);
294
295 let mut streams = OperationStreams::default();
296 streams.tensor(&lhs);
297 let dtype = lhs.dtype;
298 let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype);
299
300 let desc = ScalarOpIr {
301 lhs: lhs.into_ir(),
302 rhs: ScalarIr::with_dtype(rhs, &dtype),
303 out: out.to_ir_out(),
304 };
305 out.client.register(
306 streams,
307 OperationIr::NumericFloat(dtype, NumericOperationIr::AddScalar(desc.clone())),
308 AddOps::<B>::new(desc),
309 );
310
311 out
312 }
313
314 fn float_clamp(
315 tensor: FloatTensor<Self>,
316 min: FloatElem<Self>,
317 max: FloatElem<Self>,
318 ) -> FloatTensor<Self> {
319 #[derive(new, Debug)]
320 struct ClampOps<B: FusionBackend> {
321 desc: ClampOpIr,
322 _b: PhantomData<B>,
323 }
324
325 impl<B: FusionBackend> Operation<B::FusionRuntime> for ClampOps<B> {
326 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
327 let input = handles.get_float_tensor::<B>(&self.desc.tensor);
328 let output = B::float_clamp(input, self.desc.min.elem(), self.desc.max.elem());
329
330 handles.register_float_tensor::<B>(&self.desc.out.id, output);
331 }
332 }
333
334 let mut streams = OperationStreams::default();
335 streams.tensor(&tensor);
336 let dtype = tensor.dtype;
337 let out = tensor
338 .client
339 .tensor_uninitialized(tensor.shape.clone(), dtype);
340
341 let desc = ClampOpIr {
342 tensor: tensor.into_ir(),
343 min: ScalarIr::with_dtype(min, &dtype),
344 max: ScalarIr::with_dtype(max, &dtype),
345 out: out.to_ir_out(),
346 };
347 out.client.register(
348 streams,
349 OperationIr::NumericFloat(dtype, NumericOperationIr::Clamp(desc.clone())),
350 ClampOps::<B>::new(desc),
351 );
352
353 out
354 }
355
356 fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
357 binary_float_ops!(SubOps, B::float_sub);
358
359 let mut streams = OperationStreams::default();
360 streams.tensor(&lhs);
361 streams.tensor(&rhs);
362 let dtype = lhs.dtype;
363 let out = lhs
364 .client
365 .tensor_uninitialized(lhs.shape.broadcast(&rhs.shape).unwrap(), lhs.dtype);
366
367 let desc = BinaryOpIr {
368 lhs: lhs.into_ir(),
369 rhs: rhs.into_ir(),
370 out: out.to_ir_out(),
371 };
372 out.client.register(
373 streams,
374 OperationIr::NumericFloat(dtype, NumericOperationIr::Sub(desc.clone())),
375 SubOps::<B>::new(desc),
376 );
377
378 out
379 }
380
381 fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
382 scalar_float_ops!(SubOps, B::float_sub_scalar);
383
384 let mut streams = OperationStreams::default();
385 streams.tensor(&lhs);
386 let dtype = lhs.dtype;
387 let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype);
388 let desc = ScalarOpIr {
389 lhs: lhs.into_ir(),
390 rhs: ScalarIr::with_dtype(rhs, &dtype),
391 out: out.to_ir_out(),
392 };
393
394 out.client.register(
395 streams,
396 OperationIr::NumericFloat(dtype, NumericOperationIr::SubScalar(desc.clone())),
397 SubOps::<B>::new(desc),
398 );
399
400 out
401 }
402
403 fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
404 binary_float_ops!(MulOps, B::float_mul);
405
406 let mut streams = OperationStreams::default();
407 streams.tensor(&lhs);
408 streams.tensor(&rhs);
409 let dtype = lhs.dtype;
410 let out = lhs
411 .client
412 .tensor_uninitialized(lhs.shape.broadcast(&rhs.shape).unwrap(), lhs.dtype);
413
414 let desc = BinaryOpIr {
415 lhs: lhs.into_ir(),
416 rhs: rhs.into_ir(),
417 out: out.to_ir_out(),
418 };
419 out.client.register(
420 streams,
421 OperationIr::NumericFloat(dtype, NumericOperationIr::Mul(desc.clone())),
422 MulOps::<B>::new(desc),
423 );
424
425 out
426 }
427
428 fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
429 scalar_float_ops!(MulOps, B::float_mul_scalar);
430
431 let mut streams = OperationStreams::default();
432 streams.tensor(&lhs);
433 let dtype = lhs.dtype;
434 let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype);
435
436 let desc = ScalarOpIr {
437 lhs: lhs.into_ir(),
438 rhs: ScalarIr::with_dtype(rhs, &dtype),
439 out: out.to_ir_out(),
440 };
441 out.client.register(
442 streams,
443 OperationIr::NumericFloat(dtype, NumericOperationIr::MulScalar(desc.clone())),
444 MulOps::<B>::new(desc),
445 );
446
447 out
448 }
449
450 fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
451 binary_float_ops!(DivOps, B::float_div);
452
453 let mut streams = OperationStreams::default();
454 streams.tensor(&lhs);
455 streams.tensor(&rhs);
456 let dtype = lhs.dtype;
457 let out = lhs
458 .client
459 .tensor_uninitialized(lhs.shape.broadcast(&rhs.shape).unwrap(), lhs.dtype);
460
461 let desc = BinaryOpIr {
462 lhs: lhs.into_ir(),
463 rhs: rhs.into_ir(),
464 out: out.to_ir_out(),
465 };
466 out.client.register(
467 streams,
468 OperationIr::NumericFloat(dtype, NumericOperationIr::Div(desc.clone())),
469 DivOps::<B>::new(desc),
470 );
471
472 out
473 }
474
475 fn float_div_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
476 scalar_float_ops!(DivOps, B::float_div_scalar);
477
478 let mut streams = OperationStreams::default();
479 streams.tensor(&lhs);
480 let dtype = lhs.dtype;
481 let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype);
482
483 let desc = ScalarOpIr {
484 lhs: lhs.into_ir(),
485 rhs: ScalarIr::with_dtype(rhs, &dtype),
486 out: out.to_ir_out(),
487 };
488 out.client.register(
489 streams,
490 OperationIr::NumericFloat(dtype, NumericOperationIr::DivScalar(desc.clone())),
491 DivOps::<B>::new(desc),
492 );
493
494 out
495 }
496
497 fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
498 binary_float_ops!(ModOps, B::float_remainder);
499
500 let mut streams = OperationStreams::default();
501 streams.tensor(&lhs);
502 streams.tensor(&rhs);
503 let dtype = lhs.dtype;
504 let out = lhs
505 .client
506 .tensor_uninitialized(lhs.shape.broadcast(&rhs.shape).unwrap(), lhs.dtype);
507
508 let desc = BinaryOpIr {
509 lhs: lhs.into_ir(),
510 rhs: rhs.into_ir(),
511 out: out.to_ir_out(),
512 };
513 out.client.register(
514 streams,
515 OperationIr::NumericFloat(dtype, NumericOperationIr::Rem(desc.clone())),
516 ModOps::<B>::new(desc),
517 );
518
519 out
520 }
521
522 fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
523 scalar_float_ops!(ModOps, B::float_remainder_scalar);
524
525 let mut streams = OperationStreams::default();
526 streams.tensor(&lhs);
527 let dtype = lhs.dtype;
528 let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype);
529
530 let desc = ScalarOpIr {
531 lhs: lhs.into_ir(),
532 rhs: ScalarIr::with_dtype(rhs, &dtype),
533 out: out.to_ir_out(),
534 };
535 out.client.register(
536 streams,
537 OperationIr::NumericFloat(dtype, NumericOperationIr::RemScalar(desc.clone())),
538 ModOps::<B>::new(desc),
539 );
540
541 out
542 }
543
544 fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
545 binary_float_ops!(MatmulOps, B::float_matmul);
546
547 let mut streams = OperationStreams::default();
548 streams.tensor(&lhs);
549 streams.tensor(&rhs);
550 let dtype = lhs.dtype;
551 let shape = Shape::matmul(&lhs.shape, &rhs.shape).unwrap();
552
553 let out = lhs.client.tensor_uninitialized(shape, dtype);
554 let desc = BinaryOpIr {
555 lhs: lhs.into_ir(),
556 rhs: rhs.into_ir(),
557 out: out.to_ir_out(),
558 };
559
560 out.client.register(
561 streams,
562 OperationIr::Float(dtype, FloatOperationIr::Matmul(desc.clone())),
563 MatmulOps::<B>::new(desc),
564 );
565
566 out
567 }
568
569 fn float_cross(
570 lhs: FloatTensor<Self>,
571 rhs: FloatTensor<Self>,
572 dim: usize,
573 ) -> FloatTensor<Self> {
574 #[derive(new, Debug)]
575 struct CrossOps<B: FusionBackend> {
576 desc: CrossOpIr,
577 _b: PhantomData<B>,
578 }
579
580 impl<B: FusionBackend> Operation<B::FusionRuntime> for CrossOps<B> {
581 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
582 let lhs = handles.get_float_tensor::<B>(&self.desc.lhs);
583 let rhs = handles.get_float_tensor::<B>(&self.desc.rhs);
584 let output = B::float_cross(lhs, rhs, self.desc.dim);
585 handles.register_float_tensor::<B>(&self.desc.out.id, output);
586 }
587 }
588
589 let mut streams = OperationStreams::default();
590 streams.tensor(&lhs);
591 streams.tensor(&rhs);
592 let dtype = lhs.dtype;
593 let out = lhs
594 .client
595 .tensor_uninitialized(lhs.shape.broadcast(&rhs.shape).unwrap(), dtype);
596 let desc = CrossOpIr {
597 lhs: lhs.into_ir(),
598 rhs: rhs.into_ir(),
599 out: out.to_ir_out(),
600 dim,
601 };
602
603 out.client.register(
604 streams,
605 OperationIr::Float(dtype, FloatOperationIr::Cross(desc.clone())),
606 CrossOps::<B>::new(desc),
607 );
608
609 out
610 }
611
612 fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
613 #[derive(new, Debug)]
614 struct SwapDimsOps<B: FusionBackend> {
615 desc: SwapDimsOpIr,
616 _b: PhantomData<B>,
617 }
618
619 impl<B: FusionBackend> Operation<B::FusionRuntime> for SwapDimsOps<B> {
620 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
621 let input = handles.get_float_tensor::<B>(&self.desc.input);
622 let output = B::float_swap_dims(input, self.desc.dim1, self.desc.dim2);
623 handles.register_float_tensor::<B>(&self.desc.out.id, output);
624 }
625 }
626
627 let mut streams = OperationStreams::default();
628 streams.tensor(&tensor);
629
630 let dtype = tensor.dtype;
631 let shape = tensor.shape.clone().swap(dim1, dim2).unwrap();
632 let mut out = tensor.client.tensor_uninitialized(shape, dtype);
633
634 let desc = SwapDimsOpIr {
635 input: tensor.into_ir(),
636 dim1,
637 dim2,
638 out: out.to_ir_out(),
639 };
640 out.client.register(
641 streams,
642 OperationIr::BaseFloat(BaseOperationIr::SwapDims(desc.clone())),
643 SwapDimsOps::<B>::new(desc),
644 );
645 out.stream = StreamId::current();
646
647 out
648 }
649
650 fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
651 if tensor.shape == shape {
652 return tensor;
653 }
654
655 #[derive(new, Debug)]
656 struct ReshapeDimsOps<B: FusionBackend> {
657 desc: UnaryOpIr,
658 _b: PhantomData<B>,
659 }
660
661 impl<B: FusionBackend> Operation<B::FusionRuntime> for ReshapeDimsOps<B> {
662 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
663 let input = handles.get_float_tensor::<B>(&self.desc.input);
664 let output = B::float_reshape(input, self.desc.out.shape.clone());
665 handles.register_float_tensor::<B>(&self.desc.out.id, output);
666 }
667 }
668
669 let mut streams = OperationStreams::default();
670 streams.tensor(&tensor);
671 let dtype = tensor.dtype;
672 let out = tensor.client.tensor_uninitialized(shape, dtype);
673
674 let desc = UnaryOpIr {
675 input: tensor.into_ir(),
676 out: out.to_ir_out(),
677 };
678 out.client.register(
679 streams,
680 OperationIr::BaseFloat(BaseOperationIr::Reshape(desc.clone())),
681 ReshapeDimsOps::<B>::new(desc),
682 );
683
684 out
685 }
686
687 fn float_gather(
688 dim: usize,
689 tensor: FloatTensor<Self>,
690 indices: IntTensor<Self>,
691 ) -> FloatTensor<Self> {
692 #[derive(new, Debug)]
693 struct GatherOps<B: FusionBackend> {
694 desc: GatherOpIr,
695 _b: PhantomData<B>,
696 }
697
698 impl<B: FusionBackend> Operation<B::FusionRuntime> for GatherOps<B> {
699 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
700 let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
701 let indices = handles.get_int_tensor::<B>(&self.desc.indices);
702
703 let output = B::float_gather(self.desc.dim, tensor, indices);
704 handles.register_float_tensor::<B>(&self.desc.out.id, output);
705 }
706 }
707
708 let mut streams = OperationStreams::default();
709 streams.tensor(&tensor);
710 streams.tensor(&indices);
711
712 let dtype = tensor.dtype;
713 let shape = indices.shape.clone();
714 let out = tensor.client.tensor_uninitialized(shape, dtype);
715
716 let desc = GatherOpIr {
717 tensor: tensor.into_ir(),
718 dim,
719 indices: indices.into_ir(),
720 out: out.to_ir_out(),
721 };
722 out.client.register(
723 streams,
724 OperationIr::NumericFloat(dtype, NumericOperationIr::Gather(desc.clone())),
725 GatherOps::<B>::new(desc),
726 );
727
728 out
729 }
730
731 fn float_scatter(
732 dim: usize,
733 tensor: FloatTensor<Self>,
734 indices: IntTensor<Self>,
735 value: FloatTensor<Self>,
736 ) -> FloatTensor<Self> {
737 #[derive(new, Debug)]
738 struct ScatterOps<B: FusionBackend> {
739 desc: ScatterOpIr,
740 _b: PhantomData<B>,
741 }
742
743 impl<B: FusionBackend> Operation<B::FusionRuntime> for ScatterOps<B> {
744 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
745 let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
746 let indices = handles.get_int_tensor::<B>(&self.desc.indices);
747 let value = handles.get_float_tensor::<B>(&self.desc.value);
748
749 let output = B::float_scatter(self.desc.dim, tensor, indices, value);
750
751 handles.register_float_tensor::<B>(&self.desc.out.id, output);
752 }
753 }
754
755 let mut streams = OperationStreams::default();
756 streams.tensor(&tensor);
757 streams.tensor(&indices);
758 streams.tensor(&value);
759
760 let shape = tensor.shape.clone();
761 let dtype = tensor.dtype;
762 let out = tensor.client.tensor_uninitialized(shape, dtype);
763
764 let desc = ScatterOpIr {
765 tensor: tensor.into_ir(),
766 dim,
767 indices: indices.into_ir(),
768 value: value.into_ir(),
769 out: out.to_ir_out(),
770 };
771 check_binary_op_types(&desc.tensor, &desc.value).unwrap();
773 out.client.register(
774 streams,
775 OperationIr::NumericFloat(dtype, NumericOperationIr::Scatter(desc.clone())),
776 ScatterOps::<B>::new(desc),
777 );
778
779 out
780 }
781
782 fn float_select(
783 tensor: FloatTensor<Self>,
784 dim: usize,
785 indices: IntTensor<Self>,
786 ) -> FloatTensor<Self> {
787 #[derive(new, Debug)]
788 struct SelectOps<B: FusionBackend> {
789 desc: SelectOpIr,
790 _b: PhantomData<B>,
791 }
792
793 impl<B: FusionBackend> Operation<B::FusionRuntime> for SelectOps<B> {
794 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
795 let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
796 let indices = handles.get_int_tensor::<B>(&self.desc.indices);
797
798 let output = B::float_select(tensor, self.desc.dim, indices);
799
800 handles.register_float_tensor::<B>(&self.desc.out.id, output);
801 }
802 }
803
804 let mut streams = OperationStreams::default();
805 streams.tensor(&tensor);
806 streams.tensor(&indices);
807
808 let dtype = tensor.dtype;
809 let mut shape = tensor.shape.clone();
810 shape[dim] = indices.shape[0];
811 let out = tensor.client.tensor_uninitialized(shape, dtype);
812 let desc = SelectOpIr {
813 tensor: tensor.into_ir(),
814 dim,
815 indices: indices.into_ir(),
816 out: out.to_ir_out(),
817 };
818 out.client.register(
819 streams,
820 OperationIr::NumericFloat(dtype, NumericOperationIr::Select(desc.clone())),
821 SelectOps::<B>::new(desc),
822 );
823
824 out
825 }
826
827 fn float_select_assign(
828 tensor: FloatTensor<Self>,
829 dim: usize,
830 indices: IntTensor<Self>,
831 value: FloatTensor<Self>,
832 ) -> FloatTensor<Self> {
833 #[derive(new, Debug)]
834 struct SelectAssignOps<B: FusionBackend> {
835 desc: SelectAssignOpIr,
836 _b: PhantomData<B>,
837 }
838
839 impl<B: FusionBackend> Operation<B::FusionRuntime> for SelectAssignOps<B> {
840 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
841 let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
842 let indices = handles.get_int_tensor::<B>(&self.desc.indices);
843 let value = handles.get_float_tensor::<B>(&self.desc.value);
844
845 let output = B::float_select_assign(tensor, self.desc.dim, indices, value);
846
847 handles.register_float_tensor::<B>(&self.desc.out.id, output);
848 }
849 }
850
851 let mut streams = OperationStreams::default();
852 streams.tensor(&tensor);
853 streams.tensor(&indices);
854 streams.tensor(&value);
855
856 let dtype = tensor.dtype;
857 let shape = tensor.shape.clone();
858 let out = tensor.client.tensor_uninitialized(shape, dtype);
859
860 let desc = SelectAssignOpIr {
861 tensor: tensor.into_ir(),
862 dim,
863 indices: indices.into_ir(),
864 value: value.into_ir(),
865 out: out.to_ir_out(),
866 };
867 check_binary_op_types(&desc.tensor, &desc.value).unwrap();
869 out.client.register(
870 streams,
871 OperationIr::NumericFloat(dtype, NumericOperationIr::SelectAssign(desc.clone())),
872 SelectAssignOps::<B>::new(desc),
873 );
874
875 out
876 }
877
878 fn float_slice(tensor: FloatTensor<Self>, slices: &[Slice]) -> FloatTensor<Self> {
879 #[derive(new, Debug)]
880 struct SliceOps<B: FusionBackend> {
881 desc: SliceOpIr,
882 _b: PhantomData<B>,
883 }
884
885 impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceOps<B> {
886 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
887 let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
888
889 let output = B::float_slice(tensor, self.desc.ranges.as_slice());
890
891 handles.register_float_tensor::<B>(&self.desc.out.id, output);
892 }
893 }
894 let mut streams = OperationStreams::default();
895 streams.tensor(&tensor);
896 let dtype = tensor.dtype;
897 let shape = tensor.shape.clone().slice(slices).unwrap();
898
899 let out = tensor.client.tensor_uninitialized(shape, dtype);
900
901 let desc = SliceOpIr {
902 tensor: tensor.into_ir(),
903 ranges: slices.to_vec(),
904 out: out.to_ir_out(),
905 };
906 out.client.register(
907 streams,
908 OperationIr::BaseFloat(BaseOperationIr::Slice(desc.clone())),
909 SliceOps::<B>::new(desc),
910 );
911
912 out
913 }
914
915 fn float_slice_assign(
916 tensor: FloatTensor<Self>,
917 ranges: &[burn_tensor::Slice],
918 value: FloatTensor<Self>,
919 ) -> FloatTensor<Self> {
920 #[derive(new, Debug)]
921 struct SliceAssignOps<B: FusionBackend> {
922 desc: SliceAssignOpIr,
923 _b: PhantomData<B>,
924 }
925
926 impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceAssignOps<B> {
927 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
928 let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
929 let value = handles.get_float_tensor::<B>(&self.desc.value);
930
931 let output = B::float_slice_assign(tensor, self.desc.ranges.as_slice(), value);
932
933 handles.register_float_tensor::<B>(&self.desc.out.id, output);
934 }
935 }
936
937 let mut streams = OperationStreams::default();
938 streams.tensor(&tensor);
939 streams.tensor(&value);
940
941 let dtype = tensor.dtype;
942 let shape = tensor.shape.clone();
943 let out = tensor.client.tensor_uninitialized(shape, dtype);
944
945 let desc = SliceAssignOpIr {
946 tensor: tensor.into_ir(),
947 ranges: ranges.to_vec(),
948 value: value.into_ir(),
949 out: out.to_ir_out(),
950 };
951 check_binary_op_types(&desc.tensor, &desc.value).unwrap();
953 out.client.register(
954 streams,
955 OperationIr::BaseFloat(BaseOperationIr::SliceAssign(desc.clone())),
956 SliceAssignOps::<B>::new(desc),
957 );
958
959 out
960 }
961
962 fn float_mask_where(
963 tensor: FloatTensor<Self>,
964 mask: BoolTensor<Self>,
965 value: FloatTensor<Self>,
966 ) -> FloatTensor<Self> {
967 #[derive(new, Debug)]
968 struct MaskWhereOps<B: FusionBackend> {
969 desc: MaskWhereOpIr,
970 _b: PhantomData<B>,
971 }
972
973 impl<B: FusionBackend> Operation<B::FusionRuntime> for MaskWhereOps<B> {
974 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
975 let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
976 let value = handles.get_float_tensor::<B>(&self.desc.value);
977 let mask = handles.get_bool_tensor::<B>(&self.desc.mask);
978
979 let output = B::float_mask_where(tensor, mask, value);
980
981 handles.register_float_tensor::<B>(&self.desc.out.id, output);
982 }
983 }
984
985 let mut streams = OperationStreams::default();
986 streams.tensor(&tensor);
987 streams.tensor(&mask);
988 streams.tensor(&value);
989
990 let dtype = tensor.dtype;
991 let out = tensor
992 .client
993 .tensor_uninitialized(tensor.shape.broadcast(&mask.shape).unwrap(), dtype);
994
995 let desc = MaskWhereOpIr {
996 tensor: tensor.into_ir(),
997 value: value.into_ir(),
998 mask: mask.into_ir(),
999 out: out.to_ir_out(),
1000 };
1001 check_binary_op_types(&desc.tensor, &desc.value).unwrap();
1003 out.client.register(
1004 streams,
1005 OperationIr::NumericFloat(dtype, NumericOperationIr::MaskWhere(desc.clone())),
1006 MaskWhereOps::<B>::new(desc),
1007 );
1008
1009 out
1010 }
1011
1012 fn float_mask_fill(
1013 tensor: FloatTensor<Self>,
1014 mask: BoolTensor<Self>,
1015 value: FloatElem<Self>,
1016 ) -> FloatTensor<Self> {
1017 #[derive(new, Debug)]
1018 struct MaskFillOps<B: FusionBackend> {
1019 desc: MaskFillOpIr,
1020 _b: PhantomData<B>,
1021 }
1022
1023 impl<B: FusionBackend> Operation<B::FusionRuntime> for MaskFillOps<B> {
1024 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1025 let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
1026 let mask = handles.get_bool_tensor::<B>(&self.desc.mask);
1027
1028 let output = B::float_mask_fill(tensor, mask, self.desc.value.elem());
1029
1030 handles.register_float_tensor::<B>(&self.desc.out.id, output);
1031 }
1032 }
1033
1034 let mut streams = OperationStreams::default();
1035 streams.tensor(&tensor);
1036 streams.tensor(&mask);
1037
1038 let dtype = tensor.dtype;
1039 let shape = tensor.shape.clone();
1040 let out = tensor.client.tensor_uninitialized(shape, dtype);
1041 let desc = MaskFillOpIr {
1042 tensor: tensor.into_ir(),
1043 value: ScalarIr::with_dtype(value, &dtype),
1044 mask: mask.into_ir(),
1045 out: out.to_ir_out(),
1046 };
1047 out.client.register(
1048 streams,
1049 OperationIr::NumericFloat(dtype, NumericOperationIr::MaskFill(desc.clone())),
1050 MaskFillOps::<B>::new(desc),
1051 );
1052
1053 out
1054 }
1055
1056 fn float_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
1057 binary_float_cmp_ops!(EqualOps, B::float_equal);
1058
1059 let mut streams = OperationStreams::default();
1060 streams.tensor(&lhs);
1061 streams.tensor(&rhs);
1062 let out = lhs.client.tensor_uninitialized(
1063 lhs.shape.broadcast(&rhs.shape).unwrap(),
1064 B::BoolElem::dtype(),
1065 );
1066
1067 let desc = BinaryOpIr {
1068 lhs: lhs.into_ir(),
1069 rhs: rhs.into_ir(),
1070 out: out.to_ir_out(),
1071 };
1072 out.client.register(
1073 streams,
1074 OperationIr::BaseFloat(BaseOperationIr::Equal(desc.clone())),
1075 EqualOps::<B>::new(desc),
1076 );
1077
1078 out
1079 }
1080
1081 fn float_equal_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
1082 scalar_float_cmp_ops!(EqualElemOps, B::float_equal_elem);
1083
1084 let mut streams = OperationStreams::default();
1085 streams.tensor(&lhs);
1086 let dtype = lhs.dtype;
1087 let out = lhs
1088 .client
1089 .tensor_uninitialized(lhs.shape.clone(), B::BoolElem::dtype());
1090
1091 let desc = ScalarOpIr {
1092 lhs: lhs.into_ir(),
1093 rhs: ScalarIr::with_dtype(rhs, &dtype),
1094 out: out.to_ir_out(),
1095 };
1096 out.client.register(
1097 streams,
1098 OperationIr::NumericFloat(dtype, NumericOperationIr::EqualElem(desc.clone())),
1099 EqualElemOps::<B>::new(desc),
1100 );
1101
1102 out
1103 }
1104
1105 fn float_greater(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
1106 binary_float_cmp_ops!(GreaterOps, B::float_greater);
1107
1108 let mut streams = OperationStreams::default();
1109 streams.tensor(&lhs);
1110 streams.tensor(&rhs);
1111 let dtype = lhs.dtype;
1112 let out = lhs.client.tensor_uninitialized(
1113 lhs.shape.broadcast(&rhs.shape).unwrap(),
1114 B::BoolElem::dtype(),
1115 );
1116
1117 let desc = BinaryOpIr {
1118 lhs: lhs.into_ir(),
1119 rhs: rhs.into_ir(),
1120 out: out.to_ir_out(),
1121 };
1122 out.client.register(
1123 streams,
1124 OperationIr::NumericFloat(dtype, NumericOperationIr::Greater(desc.clone())),
1125 GreaterOps::<B>::new(desc),
1126 );
1127
1128 out
1129 }
1130
1131 fn float_greater_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
1132 scalar_float_cmp_ops!(GreaterElemOps, B::float_greater_elem);
1133
1134 let mut streams = OperationStreams::default();
1135 streams.tensor(&lhs);
1136 let dtype = lhs.dtype;
1137 let out = lhs
1138 .client
1139 .tensor_uninitialized(lhs.shape.clone(), B::BoolElem::dtype());
1140
1141 let desc = ScalarOpIr {
1142 lhs: lhs.into_ir(),
1143 rhs: ScalarIr::with_dtype(rhs, &dtype),
1144 out: out.to_ir_out(),
1145 };
1146 out.client.register(
1147 streams,
1148 OperationIr::NumericFloat(dtype, NumericOperationIr::GreaterElem(desc.clone())),
1149 GreaterElemOps::<B>::new(desc),
1150 );
1151
1152 out
1153 }
1154
1155 fn float_greater_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
1156 binary_float_cmp_ops!(GreaterEqualOps, B::float_greater_equal);
1157
1158 let mut streams = OperationStreams::default();
1159 streams.tensor(&lhs);
1160 streams.tensor(&rhs);
1161 let dtype = lhs.dtype;
1162 let out = lhs.client.tensor_uninitialized(
1163 lhs.shape.broadcast(&rhs.shape).unwrap(),
1164 B::BoolElem::dtype(),
1165 );
1166
1167 let desc = BinaryOpIr {
1168 lhs: lhs.into_ir(),
1169 rhs: rhs.into_ir(),
1170 out: out.to_ir_out(),
1171 };
1172 out.client.register(
1173 streams,
1174 OperationIr::NumericFloat(dtype, NumericOperationIr::GreaterEqual(desc.clone())),
1175 GreaterEqualOps::<B>::new(desc),
1176 );
1177
1178 out
1179 }
1180
1181 fn float_greater_equal_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
1182 scalar_float_cmp_ops!(GreaterEqualElemOps, B::float_greater_equal_elem);
1183
1184 let mut streams = OperationStreams::default();
1185 streams.tensor(&lhs);
1186 let dtype = lhs.dtype;
1187 let out = lhs
1188 .client
1189 .tensor_uninitialized(lhs.shape.clone(), B::BoolElem::dtype());
1190
1191 let desc = ScalarOpIr {
1192 lhs: lhs.into_ir(),
1193 rhs: ScalarIr::with_dtype(rhs, &dtype),
1194 out: out.to_ir_out(),
1195 };
1196 out.client.register(
1197 streams,
1198 OperationIr::NumericFloat(dtype, NumericOperationIr::GreaterEqualElem(desc.clone())),
1199 GreaterEqualElemOps::<B>::new(desc),
1200 );
1201
1202 out
1203 }
1204
1205 fn float_lower(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
1206 binary_float_cmp_ops!(LowerOps, B::float_lower);
1207
1208 let mut streams = OperationStreams::default();
1209 streams.tensor(&lhs);
1210 streams.tensor(&rhs);
1211 let dtype = lhs.dtype;
1212 let out = lhs.client.tensor_uninitialized(
1213 lhs.shape.broadcast(&rhs.shape).unwrap(),
1214 B::BoolElem::dtype(),
1215 );
1216
1217 let desc = BinaryOpIr {
1218 lhs: lhs.into_ir(),
1219 rhs: rhs.into_ir(),
1220 out: out.to_ir_out(),
1221 };
1222 out.client.register(
1223 streams,
1224 OperationIr::NumericFloat(dtype, NumericOperationIr::Lower(desc.clone())),
1225 LowerOps::<B>::new(desc),
1226 );
1227
1228 out
1229 }
1230
1231 fn float_lower_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
1232 scalar_float_cmp_ops!(LowerElemOps, B::float_lower_elem);
1233
1234 let mut streams = OperationStreams::default();
1235 streams.tensor(&lhs);
1236 let dtype = lhs.dtype;
1237 let out = lhs
1238 .client
1239 .tensor_uninitialized(lhs.shape.clone(), B::BoolElem::dtype());
1240
1241 let desc = ScalarOpIr {
1242 lhs: lhs.into_ir(),
1243 rhs: ScalarIr::with_dtype(rhs, &dtype),
1244 out: out.to_ir_out(),
1245 };
1246 out.client.register(
1247 streams,
1248 OperationIr::NumericFloat(dtype, NumericOperationIr::LowerElem(desc.clone())),
1249 LowerElemOps::<B>::new(desc),
1250 );
1251
1252 out
1253 }
1254
1255 fn float_lower_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
1256 binary_float_cmp_ops!(LowerEqualOps, B::float_lower_equal);
1257
1258 let mut streams = OperationStreams::default();
1259 streams.tensor(&lhs);
1260 streams.tensor(&rhs);
1261 let dtype = lhs.dtype;
1262 let out = lhs.client.tensor_uninitialized(
1263 lhs.shape.broadcast(&rhs.shape).unwrap(),
1264 B::BoolElem::dtype(),
1265 );
1266
1267 let desc = BinaryOpIr {
1268 lhs: lhs.into_ir(),
1269 rhs: rhs.into_ir(),
1270 out: out.to_ir_out(),
1271 };
1272 out.client.register(
1273 streams,
1274 OperationIr::NumericFloat(dtype, NumericOperationIr::LowerEqual(desc.clone())),
1275 LowerEqualOps::<B>::new(desc),
1276 );
1277
1278 out
1279 }
1280
1281 fn float_lower_equal_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
1282 scalar_float_cmp_ops!(LowerEqualElemOps, B::float_lower_equal_elem);
1283
1284 let mut streams = OperationStreams::default();
1285 streams.tensor(&lhs);
1286 let dtype = lhs.dtype;
1287 let out = lhs
1288 .client
1289 .tensor_uninitialized(lhs.shape.clone(), B::BoolElem::dtype());
1290
1291 let desc = ScalarOpIr {
1292 lhs: lhs.into_ir(),
1293 rhs: ScalarIr::with_dtype(rhs, &dtype),
1294 out: out.to_ir_out(),
1295 };
1296 out.client.register(
1297 streams,
1298 OperationIr::NumericFloat(dtype, NumericOperationIr::LowerEqualElem(desc.clone())),
1299 LowerEqualElemOps::<B>::new(desc),
1300 );
1301
1302 out
1303 }
1304
1305 fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1306 unary_float_ops!(SumOps, B::float_sum, reduce);
1307
1308 let mut streams = OperationStreams::default();
1309 streams.tensor(&tensor);
1310 let dtype = tensor.dtype;
1311 let out = tensor.client.tensor_uninitialized(Shape::new([1]), dtype);
1312
1313 let desc = UnaryOpIr {
1314 input: tensor.into_ir(),
1315 out: out.to_ir_out(),
1316 };
1317 out.client.register(
1318 streams,
1319 OperationIr::NumericFloat(dtype, NumericOperationIr::Sum(desc.clone())),
1320 SumOps::<B>::new(desc),
1321 );
1322
1323 out
1324 }
1325
1326 fn float_sum_dim(tensor: FloatTensor<Self>, axis: usize) -> FloatTensor<Self> {
1327 reduce_float_ops!(SumDimOps, B::float_sum_dim);
1328
1329 let mut streams = OperationStreams::default();
1330 streams.tensor(&tensor);
1331 let dtype = tensor.dtype;
1332 let mut shape = tensor.shape.clone();
1333 shape[axis] = 1;
1334 let out = tensor.client.tensor_uninitialized(shape, dtype);
1335
1336 let desc = ReduceDimOpIr {
1337 input: tensor.into_ir(),
1338 out: out.to_ir_out(),
1339 axis,
1340 };
1341 out.client.register(
1342 streams,
1343 OperationIr::NumericFloat(dtype, NumericOperationIr::SumDim(desc.clone())),
1344 SumDimOps::<B>::new(desc),
1345 );
1346
1347 out
1348 }
1349
1350 fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1351 unary_float_ops!(ProdOps, B::float_prod, reduce);
1352
1353 let mut streams = OperationStreams::default();
1354 streams.tensor(&tensor);
1355 let dtype = tensor.dtype;
1356 let out = tensor.client.tensor_uninitialized(Shape::new([1]), dtype);
1357
1358 let desc = UnaryOpIr {
1359 input: tensor.into_ir(),
1360 out: out.to_ir_out(),
1361 };
1362 out.client.register(
1363 streams,
1364 OperationIr::NumericFloat(dtype, NumericOperationIr::Prod(desc.clone())),
1365 ProdOps::<B>::new(desc),
1366 );
1367
1368 out
1369 }
1370
1371 fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
1372 reduce_float_ops!(ProdDimOps, B::float_prod_dim);
1373
1374 let mut streams = OperationStreams::default();
1375 streams.tensor(&tensor);
1376 let dtype = tensor.dtype;
1377 let mut shape = tensor.shape.clone();
1378 shape[dim] = 1;
1379 let out = tensor.client.tensor_uninitialized(shape, dtype);
1380
1381 let desc = ReduceDimOpIr {
1382 input: tensor.into_ir(),
1383 axis: dim,
1384 out: out.to_ir_out(),
1385 };
1386
1387 out.client.register(
1388 streams,
1389 OperationIr::NumericFloat(dtype, NumericOperationIr::ProdDim(desc.clone())),
1390 ProdDimOps::<B>::new(desc),
1391 );
1392
1393 out
1394 }
1395
1396 fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1397 unary_float_ops!(MeanOps, B::float_mean, reduce);
1398
1399 let mut streams = OperationStreams::default();
1400 streams.tensor(&tensor);
1401 let dtype = tensor.dtype;
1402 let out = tensor.client.tensor_uninitialized(Shape::new([1]), dtype);
1403
1404 let desc = UnaryOpIr {
1405 input: tensor.into_ir(),
1406 out: out.to_ir_out(),
1407 };
1408 out.client.register(
1409 streams,
1410 OperationIr::NumericFloat(dtype, NumericOperationIr::Mean(desc.clone())),
1411 MeanOps::<B>::new(desc),
1412 );
1413
1414 out
1415 }
1416
1417 fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
1418 reduce_float_ops!(MeanDimOps, B::float_mean_dim);
1419
1420 let mut streams = OperationStreams::default();
1421 streams.tensor(&tensor);
1422 let dtype = tensor.dtype;
1423 let mut shape = tensor.shape.clone();
1424 shape[dim] = 1;
1425 let out = tensor.client.tensor_uninitialized(shape, dtype);
1426
1427 let desc = ReduceDimOpIr {
1428 input: tensor.into_ir(),
1429 axis: dim,
1430 out: out.to_ir_out(),
1431 };
1432 out.client.register(
1433 streams,
1434 OperationIr::NumericFloat(dtype, NumericOperationIr::MeanDim(desc.clone())),
1435 MeanDimOps::<B>::new(desc),
1436 );
1437
1438 out
1439 }
1440
1441 fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
1442 #[derive(new, Debug)]
1443 struct CumsumOps<B: FusionBackend> {
1444 desc: DimOpIr,
1445 _b: PhantomData<B>,
1446 }
1447
1448 impl<B: FusionBackend> Operation<B::FusionRuntime> for CumsumOps<B> {
1449 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1450 let input = handles.get_float_tensor::<B>(&self.desc.input);
1451 let output = B::float_cumsum(input, self.desc.axis);
1452 handles.register_float_tensor::<B>(&self.desc.out.id, output);
1453 }
1454 }
1455
1456 let mut streams = OperationStreams::default();
1457 streams.tensor(&tensor);
1458 let dtype = tensor.dtype;
1459 let shape = tensor.shape.clone();
1460 let out = tensor.client.tensor_uninitialized(shape, dtype);
1461
1462 let desc = DimOpIr {
1463 input: tensor.into_ir(),
1464 out: out.to_ir_out(),
1465 axis: dim,
1466 };
1467
1468 out.client.register(
1469 streams,
1470 OperationIr::BaseFloat(BaseOperationIr::CumSum(desc.clone())),
1471 CumsumOps::<B>::new(desc),
1472 );
1473
1474 out
1475 }
1476
1477 fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
1478 #[derive(new, Debug)]
1479 struct CumprodOps<B: FusionBackend> {
1480 desc: DimOpIr,
1481 _b: PhantomData<B>,
1482 }
1483
1484 impl<B: FusionBackend> Operation<B::FusionRuntime> for CumprodOps<B> {
1485 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1486 let input = handles.get_float_tensor::<B>(&self.desc.input);
1487 let output = B::float_cumprod(input, self.desc.axis);
1488 handles.register_float_tensor::<B>(&self.desc.out.id, output);
1489 }
1490 }
1491
1492 let mut streams = OperationStreams::default();
1493 streams.tensor(&tensor);
1494 let dtype = tensor.dtype;
1495 let shape = tensor.shape.clone();
1496 let out = tensor.client.tensor_uninitialized(shape, dtype);
1497
1498 let desc = DimOpIr {
1499 input: tensor.into_ir(),
1500 out: out.to_ir_out(),
1501 axis: dim,
1502 };
1503
1504 out.client.register(
1505 streams,
1506 OperationIr::BaseFloat(BaseOperationIr::CumProd(desc.clone())),
1507 CumprodOps::<B>::new(desc),
1508 );
1509
1510 out
1511 }
1512
1513 fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
1514 #[derive(new, Debug)]
1515 struct CumminOps<B: FusionBackend> {
1516 desc: DimOpIr,
1517 _b: PhantomData<B>,
1518 }
1519
1520 impl<B: FusionBackend> Operation<B::FusionRuntime> for CumminOps<B> {
1521 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1522 let input = handles.get_float_tensor::<B>(&self.desc.input);
1523 let output = B::float_cummin(input, self.desc.axis);
1524 handles.register_float_tensor::<B>(&self.desc.out.id, output);
1525 }
1526 }
1527
1528 let mut streams = OperationStreams::default();
1529 streams.tensor(&tensor);
1530 let dtype = tensor.dtype;
1531 let shape = tensor.shape.clone();
1532 let out = tensor.client.tensor_uninitialized(shape, dtype);
1533
1534 let desc = DimOpIr {
1535 input: tensor.into_ir(),
1536 out: out.to_ir_out(),
1537 axis: dim,
1538 };
1539
1540 out.client.register(
1541 streams,
1542 OperationIr::BaseFloat(BaseOperationIr::CumMin(desc.clone())),
1543 CumminOps::<B>::new(desc),
1544 );
1545
1546 out
1547 }
1548
1549 fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
1550 #[derive(new, Debug)]
1551 struct CummaxOps<B: FusionBackend> {
1552 desc: DimOpIr,
1553 _b: PhantomData<B>,
1554 }
1555
1556 impl<B: FusionBackend> Operation<B::FusionRuntime> for CummaxOps<B> {
1557 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1558 let input = handles.get_float_tensor::<B>(&self.desc.input);
1559 let output = B::float_cummax(input, self.desc.axis);
1560 handles.register_float_tensor::<B>(&self.desc.out.id, output);
1561 }
1562 }
1563
1564 let mut streams = OperationStreams::default();
1565 streams.tensor(&tensor);
1566 let dtype = tensor.dtype;
1567 let shape = tensor.shape.clone();
1568 let out = tensor.client.tensor_uninitialized(shape, dtype);
1569
1570 let desc = DimOpIr {
1571 input: tensor.into_ir(),
1572 out: out.to_ir_out(),
1573 axis: dim,
1574 };
1575
1576 out.client.register(
1577 streams,
1578 OperationIr::BaseFloat(BaseOperationIr::CumMax(desc.clone())),
1579 CummaxOps::<B>::new(desc),
1580 );
1581
1582 out
1583 }
1584
1585 fn float_exp(lhs: FloatTensor<Self>) -> FloatTensor<Self> {
1586 unary_float_ops!(ExpOps, B::float_exp);
1587
1588 let mut streams = OperationStreams::default();
1589 streams.tensor(&lhs);
1590 let dtype = lhs.dtype;
1591 let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype);
1592
1593 let desc = UnaryOpIr {
1594 input: lhs.into_ir(),
1595 out: out.to_ir_out(),
1596 };
1597 out.client.register(
1598 streams,
1599 OperationIr::Float(dtype, FloatOperationIr::Exp(desc.clone())),
1600 ExpOps::<B>::new(desc),
1601 );
1602
1603 out
1604 }
1605
1606 fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1607 unary_float_ops!(LogOps, B::float_log);
1608
1609 let mut streams = OperationStreams::default();
1610 streams.tensor(&tensor);
1611 let dtype = tensor.dtype;
1612 let out = tensor
1613 .client
1614 .tensor_uninitialized(tensor.shape.clone(), dtype);
1615
1616 let desc = UnaryOpIr {
1617 input: tensor.into_ir(),
1618 out: out.to_ir_out(),
1619 };
1620 out.client.register(
1621 streams,
1622 OperationIr::Float(dtype, FloatOperationIr::Log(desc.clone())),
1623 LogOps::<B>::new(desc),
1624 );
1625
1626 out
1627 }
1628
1629 fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1630 unary_float_ops!(Log1pOps, B::float_log1p);
1631
1632 let mut streams = OperationStreams::default();
1633 streams.tensor(&tensor);
1634 let dtype = tensor.dtype;
1635 let out = tensor
1636 .client
1637 .tensor_uninitialized(tensor.shape.clone(), dtype);
1638
1639 let desc = UnaryOpIr {
1640 input: tensor.into_ir(),
1641 out: out.to_ir_out(),
1642 };
1643 out.client.register(
1644 streams,
1645 OperationIr::Float(dtype, FloatOperationIr::Log1p(desc.clone())),
1646 Log1pOps::<B>::new(desc),
1647 );
1648
1649 out
1650 }
1651
1652 fn float_powf_scalar_impl(lhs: FloatTensor<Self>, rhs: f32) -> FloatTensor<Self> {
1653 scalar_float_ops!(PowfOps, B::float_powf_scalar);
1654
1655 let mut streams = OperationStreams::default();
1656 streams.tensor(&lhs);
1657 let dtype = lhs.dtype;
1658 let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype);
1659
1660 let desc = ScalarOpIr {
1661 lhs: lhs.into_ir(),
1662 rhs: ScalarIr::with_dtype(rhs, &dtype),
1663 out: out.to_ir_out(),
1664 };
1665 out.client.register(
1666 streams,
1667 OperationIr::Float(dtype, FloatOperationIr::PowfScalar(desc.clone())),
1668 PowfOps::<B>::new(desc),
1669 );
1670
1671 out
1672 }
1673
1674 fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1675 unary_float_ops!(SqrtOps, B::float_sqrt);
1676
1677 let mut streams = OperationStreams::default();
1678 streams.tensor(&tensor);
1679 let dtype = tensor.dtype;
1680 let out = tensor
1681 .client
1682 .tensor_uninitialized(tensor.shape.clone(), dtype);
1683
1684 let desc = UnaryOpIr {
1685 input: tensor.into_ir(),
1686 out: out.to_ir_out(),
1687 };
1688 out.client.register(
1689 streams,
1690 OperationIr::Float(dtype, FloatOperationIr::Sqrt(desc.clone())),
1691 SqrtOps::<B>::new(desc),
1692 );
1693
1694 out
1695 }
1696
1697 fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1698 unary_float_ops!(AbsOps, B::float_abs);
1699
1700 let mut streams = OperationStreams::default();
1701 streams.tensor(&tensor);
1702 let dtype = tensor.dtype;
1703 let out = tensor
1704 .client
1705 .tensor_uninitialized(tensor.shape.clone(), dtype);
1706
1707 let desc = UnaryOpIr {
1708 input: tensor.into_ir(),
1709 out: out.to_ir_out(),
1710 };
1711 out.client.register(
1712 streams,
1713 OperationIr::NumericFloat(dtype, NumericOperationIr::Abs(desc.clone())),
1714 AbsOps::<B>::new(desc),
1715 );
1716
1717 out
1718 }
1719
1720 fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1721 unary_float_ops!(CosOps, B::float_cos);
1722
1723 let mut streams = OperationStreams::default();
1724 streams.tensor(&tensor);
1725 let dtype = tensor.dtype;
1726 let out = tensor
1727 .client
1728 .tensor_uninitialized(tensor.shape.clone(), dtype);
1729
1730 let desc = UnaryOpIr {
1731 input: tensor.into_ir(),
1732 out: out.to_ir_out(),
1733 };
1734 out.client.register(
1735 streams,
1736 OperationIr::Float(dtype, FloatOperationIr::Cos(desc.clone())),
1737 CosOps::<B>::new(desc),
1738 );
1739
1740 out
1741 }
1742
1743 fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1744 unary_float_ops!(SinOps, B::float_sin);
1745
1746 let mut streams = OperationStreams::default();
1747 streams.tensor(&tensor);
1748 let dtype = tensor.dtype;
1749 let out = tensor
1750 .client
1751 .tensor_uninitialized(tensor.shape.clone(), dtype);
1752
1753 let desc = UnaryOpIr {
1754 input: tensor.into_ir(),
1755 out: out.to_ir_out(),
1756 };
1757 out.client.register(
1758 streams,
1759 OperationIr::Float(dtype, FloatOperationIr::Sin(desc.clone())),
1760 SinOps::<B>::new(desc),
1761 );
1762
1763 out
1764 }
1765
1766 fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1767 unary_float_ops!(TanhOps, B::float_tanh);
1768
1769 let mut streams = OperationStreams::default();
1770 streams.tensor(&tensor);
1771 let dtype = tensor.dtype;
1772 let out = tensor
1773 .client
1774 .tensor_uninitialized(tensor.shape.clone(), dtype);
1775
1776 let desc = UnaryOpIr {
1777 input: tensor.into_ir(),
1778 out: out.to_ir_out(),
1779 };
1780 out.client.register(
1781 streams,
1782 OperationIr::Float(dtype, FloatOperationIr::Tanh(desc.clone())),
1783 TanhOps::<B>::new(desc),
1784 );
1785
1786 out
1787 }
1788
1789 fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1790 unary_float_ops!(Recip, B::float_recip);
1791
1792 let mut streams = OperationStreams::default();
1793 streams.tensor(&tensor);
1794 let dtype = tensor.dtype;
1795 let out = tensor
1796 .client
1797 .tensor_uninitialized(tensor.shape.clone(), dtype);
1798 let desc = UnaryOpIr {
1799 input: tensor.into_ir(),
1800 out: out.to_ir_out(),
1801 };
1802 out.client.register(
1803 streams,
1804 OperationIr::Float(dtype, FloatOperationIr::Recip(desc.clone())),
1805 Recip::<B>::new(desc),
1806 );
1807
1808 out
1809 }
1810
1811 fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1812 unary_float_ops!(TanhOps, B::float_erf);
1813
1814 let mut streams = OperationStreams::default();
1815 streams.tensor(&tensor);
1816 let dtype = tensor.dtype;
1817 let out = tensor
1818 .client
1819 .tensor_uninitialized(tensor.shape.clone(), dtype);
1820
1821 let desc = UnaryOpIr {
1822 input: tensor.into_ir(),
1823 out: out.to_ir_out(),
1824 };
1825 out.client.register(
1826 streams,
1827 OperationIr::Float(dtype, FloatOperationIr::Erf(desc.clone())),
1828 TanhOps::<B>::new(desc),
1829 );
1830
1831 out
1832 }
1833
1834 fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {
1835 #[derive(new, Debug)]
1836 struct CatOps<B: FusionBackend> {
1837 desc: CatOpIr,
1838 _b: PhantomData<B>,
1839 }
1840
1841 impl<B: FusionBackend> Operation<B::FusionRuntime> for CatOps<B> {
1842 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1843 let tensors = self
1844 .desc
1845 .tensors
1846 .iter()
1847 .map(|tensor| handles.get_float_tensor::<B>(tensor))
1848 .collect();
1849
1850 let output = B::float_cat(tensors, self.desc.dim);
1851
1852 handles.register_float_tensor::<B>(&self.desc.out.id, output);
1853 }
1854 }
1855
1856 let tensor_first = tensors.first().unwrap();
1857 let dtype = tensor_first.dtype;
1858 let client = tensor_first.client.clone();
1859
1860 let mut streams = OperationStreams::default();
1861 tensors.iter().for_each(|tensor| streams.tensor(tensor));
1862
1863 let shape = Shape::cat(tensors.iter().map(|t| &t.shape), dim).unwrap();
1865
1866 let out = client.tensor_uninitialized(shape, dtype);
1867
1868 let desc = CatOpIr {
1869 tensors: tensors.into_iter().map(|t| t.into_ir()).collect(),
1870 dim,
1871 out: out.to_ir_out(),
1872 };
1873 desc.tensors
1874 .windows(2)
1875 .for_each(|desc| check_binary_op_types(&desc[0], &desc[1]).unwrap());
1876 client.register(
1877 streams,
1878 OperationIr::BaseFloat(BaseOperationIr::Cat(desc.clone())),
1879 CatOps::<B>::new(desc),
1880 );
1881
1882 out
1883 }
1884
1885 fn float_argmax(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {
1886 reduce_float2int_ops!(ArgMaxOps, B::float_argmax);
1887
1888 let mut streams = OperationStreams::default();
1889 streams.tensor(&tensor);
1890 let dtype = tensor.dtype;
1891 let mut shape = tensor.shape.clone();
1892 shape[dim] = 1;
1893 let out = tensor
1894 .client
1895 .tensor_uninitialized(shape, B::IntElem::dtype());
1896
1897 let desc = ReduceDimOpIr {
1898 input: tensor.into_ir(),
1899 axis: dim,
1900 out: out.to_ir_out(),
1901 };
1902 out.client.register(
1903 streams,
1904 OperationIr::NumericFloat(dtype, NumericOperationIr::ArgMax(desc.clone())),
1905 ArgMaxOps::<B>::new(desc),
1906 );
1907
1908 out
1909 }
1910
1911 fn float_repeat_dim(tensor: FloatTensor<Self>, dim: usize, times: usize) -> FloatTensor<Self> {
1912 #[derive(new, Debug)]
1913 struct RepeatDimOps<B: FusionBackend> {
1914 desc: RepeatDimOpIr,
1915 _b: PhantomData<B>,
1916 }
1917
1918 impl<B: FusionBackend> Operation<B::FusionRuntime> for RepeatDimOps<B> {
1919 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1920 let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
1921
1922 let output = B::float_repeat_dim(tensor, self.desc.dim, self.desc.times);
1923
1924 handles.register_float_tensor::<B>(&self.desc.out.id, output);
1925 }
1926 }
1927
1928 let mut streams = OperationStreams::default();
1929 streams.tensor(&tensor);
1930 let shape = tensor.shape.clone().repeat(dim, times);
1931 let out = tensor.client.tensor_uninitialized(shape, tensor.dtype);
1932
1933 let desc = RepeatDimOpIr {
1934 tensor: tensor.into_ir(),
1935 dim,
1936 times,
1937 out: out.to_ir_out(),
1938 };
1939 out.client.register(
1940 streams,
1941 OperationIr::BaseFloat(BaseOperationIr::RepeatDim(desc.clone())),
1942 RepeatDimOps::<B>::new(desc),
1943 );
1944
1945 out
1946 }
1947
1948 fn float_argmin(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {
1949 reduce_float2int_ops!(ArgMinOps, B::float_argmin);
1950
1951 let mut streams = OperationStreams::default();
1952 streams.tensor(&tensor);
1953 let mut shape = tensor.shape.clone();
1954 shape[dim] = 1;
1955 let dtype = tensor.dtype;
1956 let out = tensor
1957 .client
1958 .tensor_uninitialized(shape, B::IntElem::dtype());
1959
1960 let desc = ReduceDimOpIr {
1961 input: tensor.into_ir(),
1962 axis: dim,
1963 out: out.to_ir_out(),
1964 };
1965 out.client.register(
1966 streams,
1967 OperationIr::NumericFloat(dtype, NumericOperationIr::ArgMin(desc.clone())),
1968 ArgMinOps::<B>::new(desc),
1969 );
1970
1971 out
1972 }
1973
1974 fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1975 unary_float_ops!(MaxOps, B::float_max, reduce);
1976
1977 let mut streams = OperationStreams::default();
1978 streams.tensor(&tensor);
1979 let dtype = tensor.dtype;
1980 let out = tensor.client.tensor_uninitialized(Shape::new([1]), dtype);
1981
1982 let desc = UnaryOpIr {
1983 input: tensor.into_ir(),
1984 out: out.to_ir_out(),
1985 };
1986 out.client.register(
1987 streams,
1988 OperationIr::NumericFloat(dtype, NumericOperationIr::Max(desc.clone())),
1989 MaxOps::<B>::new(desc),
1990 );
1991
1992 out
1993 }
1994
1995 fn float_max_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
1996 reduce_float_ops!(MaxDimOps, B::float_max_dim);
1997
1998 let mut streams = OperationStreams::default();
1999 streams.tensor(&tensor);
2000 let mut shape = tensor.shape.clone();
2001 let dtype = tensor.dtype;
2002 shape[dim] = 1;
2003 let out = tensor.client.tensor_uninitialized(shape, dtype);
2004
2005 let desc = ReduceDimOpIr {
2006 input: tensor.into_ir(),
2007 axis: dim,
2008 out: out.to_ir_out(),
2009 };
2010 out.client.register(
2011 streams,
2012 OperationIr::NumericFloat(dtype, NumericOperationIr::MaxDim(desc.clone())),
2013 MaxDimOps::<B>::new(desc),
2014 );
2015
2016 out
2017 }
2018
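// The `*_dim_with_indices` variants produce two outputs from a single IR node: the
// reduced float values and the int tensor of winning indices, registered as separate handles.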
2019 fn float_max_dim_with_indices(
2020 tensor: FloatTensor<Self>,
2021 dim: usize,
2022 ) -> (FloatTensor<Self>, IntTensor<Self>) {
2023 #[derive(new, Debug)]
2024 struct MaxDimWithIndicesOps<B: FusionBackend> {
2025 desc: ReduceDimWithIndicesOpIr,
2026 _b: PhantomData<B>,
2027 }
2028
2029 impl<B: FusionBackend> Operation<B::FusionRuntime> for MaxDimWithIndicesOps<B> {
2030 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2031 let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
2032 let (output, indices) = B::float_max_dim_with_indices(tensor, self.desc.dim);
2033
2034 handles.register_float_tensor::<B>(&self.desc.out.id, output);
2035 handles.register_int_tensor::<B>(&self.desc.out_indices.id, indices);
2036 }
2037 }
2038
2039 let mut streams = OperationStreams::default();
2040 streams.tensor(&tensor);
2041
2042 let mut shape = tensor.shape.clone();
2043 shape[dim] = 1;
2044 let dtype = tensor.dtype;
2045 let client = tensor.client.clone();
2046 let out = client.tensor_uninitialized(shape.clone(), dtype);
2047 let out_indices = client.tensor_uninitialized(shape, B::IntElem::dtype());
2048
2049 let desc = ReduceDimWithIndicesOpIr {
2050 tensor: tensor.into_ir(),
2051 dim,
2052 out: out.to_ir_out(),
2053 out_indices: out_indices.to_ir_out(),
2054 };
2055 client.register(
2056 streams,
2057 OperationIr::NumericFloat(dtype, NumericOperationIr::MaxDimWithIndices(desc.clone())),
2058 MaxDimWithIndicesOps::<B>::new(desc),
2059 );
2060
2061 (out, out_indices)
2062 }
2063
2064 fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
2065 unary_float_ops!(MinOps, B::float_min, reduce);
2066
2067 let mut streams = OperationStreams::default();
2068 streams.tensor(&tensor);
2069 let dtype = tensor.dtype;
2070 let out = tensor.client.tensor_uninitialized(Shape::new([1]), dtype);
2071
2072 let desc = UnaryOpIr {
2073 input: tensor.into_ir(),
2074 out: out.to_ir_out(),
2075 };
2076 out.client.register(
2077 streams,
2078 OperationIr::NumericFloat(dtype, NumericOperationIr::Min(desc.clone())),
2079 MinOps::<B>::new(desc),
2080 );
2081
2082 out
2083 }
2084
2085 fn float_min_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
2086 reduce_float_ops!(MinDimOps, B::float_min_dim);
2087
2088 let mut streams = OperationStreams::default();
2089 streams.tensor(&tensor);
2090 let dtype = tensor.dtype;
2091 let mut shape = tensor.shape.clone();
2092 shape[dim] = 1;
2093 let out = tensor.client.tensor_uninitialized(shape, dtype);
2094
2095 let desc = ReduceDimOpIr {
2096 input: tensor.into_ir(),
2097 axis: dim,
2098 out: out.to_ir_out(),
2099 };
2100 out.client.register(
2101 streams,
2102 OperationIr::NumericFloat(dtype, NumericOperationIr::MinDim(desc.clone())),
2103 MinDimOps::<B>::new(desc),
2104 );
2105
2106 out
2107 }
2108
2109 fn float_min_dim_with_indices(
2110 tensor: FloatTensor<Self>,
2111 dim: usize,
2112 ) -> (FloatTensor<Self>, IntTensor<Self>) {
2113 #[derive(new, Debug)]
2114 struct MinDimWithIndicesOps<B: FusionBackend> {
2115 desc: ReduceDimWithIndicesOpIr,
2116 _b: PhantomData<B>,
2117 }
2118
2119 impl<B: FusionBackend> Operation<B::FusionRuntime> for MinDimWithIndicesOps<B> {
2120 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2121 let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
2122 let (output, indices) = B::float_min_dim_with_indices(tensor, self.desc.dim);
2123
2124 handles.register_float_tensor::<B>(&self.desc.out.id, output);
2125 handles.register_int_tensor::<B>(&self.desc.out_indices.id, indices);
2126 }
2127 }
2128
2129 let mut streams = OperationStreams::default();
2130 streams.tensor(&tensor);
2131 let dtype = tensor.dtype;
2132 let mut shape = tensor.shape.clone();
2133 shape[dim] = 1;
2134 let client = tensor.client.clone();
2135 let out = client.tensor_uninitialized(shape.clone(), dtype);
2136 let out_indices = client.tensor_uninitialized(shape, B::IntElem::dtype());
2137
2138 let desc = ReduceDimWithIndicesOpIr {
2139 tensor: tensor.into_ir(),
2140 dim,
2141 out: out.to_ir_out(),
2142 out_indices: out_indices.to_ir_out(),
2143 };
2144 client.register(
2145 streams,
2146 OperationIr::NumericFloat(dtype, NumericOperationIr::MinDimWithIndices(desc.clone())),
2147 MinDimWithIndicesOps::<B>::new(desc),
2148 );
2149
2150 (out, out_indices)
2151 }
2152
2153 fn float_max_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
2154 unary_float_ops!(MaxAbsOps, B::float_max_abs, reduce);
2155
2156 let mut streams = OperationStreams::default();
2157 streams.tensor(&tensor);
2158 let dtype = tensor.dtype;
2159 let out = tensor.client.tensor_uninitialized(Shape::new([1]), dtype);
2160
2161 let desc = UnaryOpIr {
2162 input: tensor.into_ir(),
2163 out: out.to_ir_out(),
2164 };
2165 out.client.register(
2166 streams,
2167 OperationIr::NumericFloat(dtype, NumericOperationIr::MaxAbs(desc.clone())),
2168 MaxAbsOps::<B>::new(desc),
2169 );
2170
2171 out
2172 }
2173
2174 fn float_max_abs_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
2175 reduce_float_ops!(MaxAbsDimOps, B::float_max_abs_dim);
2176
2177 let mut streams = OperationStreams::default();
2178 streams.tensor(&tensor);
2179 let mut shape = tensor.shape.clone();
2180 let dtype = tensor.dtype;
2181 shape[dim] = 1;
2182 let out = tensor.client.tensor_uninitialized(shape, dtype);
2183
2184 let desc = ReduceDimOpIr {
2185 input: tensor.into_ir(),
2186 axis: dim,
2187 out: out.to_ir_out(),
2188 };
2189 out.client.register(
2190 streams,
2191 OperationIr::NumericFloat(dtype, NumericOperationIr::MaxAbsDim(desc.clone())),
2192 MaxAbsDimOps::<B>::new(desc),
2193 );
2194
2195 out
2196 }
2197
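// Element-wise binary op: the output shape is the broadcast of the two operand shapes
// and the result keeps the lhs dtype; `binary_float_ops!` generates the deferred executor.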
2198 fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
2199 binary_float_ops!(PowOps, B::float_powf);
2200 let mut streams = OperationStreams::default();
2201 streams.tensor(&lhs);
2202 streams.tensor(&rhs);
2203 let dtype = lhs.dtype;
2204
2205 let out = lhs
2206 .client
2207 .tensor_uninitialized(lhs.shape.broadcast(&rhs.shape).unwrap(), dtype);
2208
2209 let desc = BinaryOpIr {
2210 lhs: lhs.into_ir(),
2211 rhs: rhs.into_ir(),
2212 out: out.to_ir_out(),
2213 };
2214 out.client.register(
2215 streams,
2216 OperationIr::NumericFloat(dtype, NumericOperationIr::Powf(desc.clone())),
2217 PowOps::<B>::new(desc),
2218 );
2219
2220 out
2221 }
2222
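// Permute reorders dimensions according to `axes`; `shape.permute(axes)` computes the
// output shape and the unwrap panics if `axes` is not a valid permutation of the dims.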
2223 fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
2224 #[derive(new, Debug)]
2225 struct PermuteDimsOps<B: FusionBackend> {
2226 desc: PermuteOpIr,
2227 _b: PhantomData<B>,
2228 }
2229
2230 impl<B: FusionBackend> Operation<B::FusionRuntime> for PermuteDimsOps<B> {
2231 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2232 let input = handles.get_float_tensor::<B>(&self.desc.input);
2233 let output = B::float_permute(input, self.desc.axes.as_slice());
2234 handles.register_float_tensor::<B>(&self.desc.out.id, output);
2235 }
2236 }
2237
2238 let mut streams = OperationStreams::default();
2239 streams.tensor(&tensor);
2240
2241 let shape = tensor.shape.clone().permute(axes).unwrap();
2242
2243 let out = tensor.client.tensor_uninitialized(shape, tensor.dtype);
2244
2245 let desc = PermuteOpIr {
2246 input: tensor.into_ir(),
2247 axes: axes.to_vec(),
2248 out: out.to_ir_out(),
2249 };
2250
2251 out.client.register(
2252 streams,
2253 OperationIr::BaseFloat(BaseOperationIr::Permute(desc.clone())),
2254 PermuteDimsOps::<B>::new(desc),
2255 );
2256
2257 out
2258 }
2259
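// Expand broadcasts the tensor to the requested `shape`; the target shape is stored in
// the IR descriptor so the backend op can be replayed when the stream executes.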
2260 fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
2261 #[derive(new, Debug)]
2262 struct ExpandOps<B: FusionBackend> {
2263 desc: ExpandOpIr,
2264 _b: PhantomData<B>,
2265 }
2266
2267 impl<B: FusionBackend> Operation<B::FusionRuntime> for ExpandOps<B> {
2268 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2269 let input = handles.get_float_tensor::<B>(&self.desc.input);
2270 let output = B::float_expand(input, self.desc.shape.clone());
2271
2272 handles.register_float_tensor::<B>(&self.desc.out.id, output);
2273 }
2274 }
2275
2276 let mut streams = OperationStreams::default();
2277 streams.tensor(&tensor);
2278
2279 let out = tensor
2280 .client
2281 .tensor_uninitialized(shape.clone(), tensor.dtype);
2282
2283 let desc = ExpandOpIr {
2284 input: tensor.into_ir(),
2285 shape,
2286 out: out.to_ir_out(),
2287 };
2288
2289 out.client.register(
2290 streams,
2291 OperationIr::BaseFloat(BaseOperationIr::Expand(desc.clone())),
2292 ExpandOps::<B>::new(desc),
2293 );
2294
2295 out
2296 }
2297
2298 fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
2299 #[derive(new, Debug)]
2300 struct FlipOps<B: FusionBackend> {
2301 desc: FlipOpIr,
2302 _b: PhantomData<B>,
2303 }
2304
2305 impl<B: FusionBackend> Operation<B::FusionRuntime> for FlipOps<B> {
2306 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2307 let input = handles.get_float_tensor::<B>(&self.desc.input);
2308 let output = B::float_flip(input, &self.desc.axes);
2309 handles.register_float_tensor::<B>(&self.desc.out.id, output);
2310 }
2311 }
2312
2313 let mut streams = OperationStreams::default();
2314 streams.tensor(&tensor);
2315 let out = tensor
2316 .client
2317 .tensor_uninitialized(tensor.shape.clone(), tensor.dtype);
2318
2319 let desc = FlipOpIr {
2320 input: tensor.into_ir(),
2321 axes: axes.to_vec(),
2322 out: out.to_ir_out(),
2323 };
2324
2325 out.client.register(
2326 streams,
2327 OperationIr::BaseFloat(BaseOperationIr::Flip(desc.clone())),
2328 FlipOps::<B>::new(desc),
2329 );
2330
2331 out
2332 }
2333
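// Rounding family (round, floor, ceil, trunc): element-wise, shape- and dtype-preserving
// unary ops, each generated through `unary_float_ops!`.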
2334 fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
2335 unary_float_ops!(RoundOps, B::float_round);
2336
2337 let mut streams = OperationStreams::default();
2338 streams.tensor(&tensor);
2339 let dtype = tensor.dtype;
2340 let out = tensor
2341 .client
2342 .tensor_uninitialized(tensor.shape.clone(), dtype);
2343
2344 let desc = UnaryOpIr {
2345 input: tensor.into_ir(),
2346 out: out.to_ir_out(),
2347 };
2348 out.client.register(
2349 streams,
2350 OperationIr::Float(dtype, FloatOperationIr::Round(desc.clone())),
2351 RoundOps::<B>::new(desc),
2352 );
2353
2354 out
2355 }
2356
2357 fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
2358 unary_float_ops!(FloorOps, B::float_floor);
2359
2360 let mut streams = OperationStreams::default();
2361 streams.tensor(&tensor);
2362 let dtype = tensor.dtype;
2363 let out = tensor
2364 .client
2365 .tensor_uninitialized(tensor.shape.clone(), dtype);
2366
2367 let desc = UnaryOpIr {
2368 input: tensor.into_ir(),
2369 out: out.to_ir_out(),
2370 };
2371 out.client.register(
2372 streams,
2373 OperationIr::Float(dtype, FloatOperationIr::Floor(desc.clone())),
2374 FloorOps::<B>::new(desc),
2375 );
2376
2377 out
2378 }
2379
2380 fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
2381 unary_float_ops!(CeilOps, B::float_ceil);
2382
2383 let mut streams = OperationStreams::default();
2384 streams.tensor(&tensor);
2385 let dtype = tensor.dtype;
2386 let out = tensor
2387 .client
2388 .tensor_uninitialized(tensor.shape.clone(), dtype);
2389
2390 let desc = UnaryOpIr {
2391 input: tensor.into_ir(),
2392 out: out.to_ir_out(),
2393 };
2394 out.client.register(
2395 streams,
2396 OperationIr::Float(dtype, FloatOperationIr::Ceil(desc.clone())),
2397 CeilOps::<B>::new(desc),
2398 );
2399
2400 out
2401 }
2402
2403 fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
2404 unary_float_ops!(TruncOps, B::float_trunc);
2405
2406 let mut streams = OperationStreams::default();
2407 streams.tensor(&tensor);
2408 let dtype = tensor.dtype;
2409 let out = tensor
2410 .client
2411 .tensor_uninitialized(tensor.shape.clone(), dtype);
2412
2413 let desc = UnaryOpIr {
2414 input: tensor.into_ir(),
2415 out: out.to_ir_out(),
2416 };
2417 out.client.register(
2418 streams,
2419 OperationIr::Float(dtype, FloatOperationIr::Trunc(desc.clone())),
2420 TruncOps::<B>::new(desc),
2421 );
2422
2423 out
2424 }
2425
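// Cast keeps the shape and changes only the float element type; the target dtype is
// carried on the op struct itself because `UnaryOpIr` does not encode it.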
2426 fn float_cast(tensor: FloatTensor<Self>, dtype: burn_tensor::FloatDType) -> FloatTensor<Self> {
2427 #[derive(new, Debug)]
2428 struct CastOps<B: FusionBackend> {
2429 desc: UnaryOpIr,
2430 dtype: burn_tensor::FloatDType,
2431 _b: PhantomData<B>,
2432 }
2433
2434 impl<B: FusionBackend> Operation<B::FusionRuntime> for CastOps<B> {
2435 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2436 let input = handles.get_float_tensor::<B>(&self.desc.input);
2437 let output: B::FloatTensorPrimitive = B::float_cast(input, self.dtype);
2438 handles.register_float_tensor::<B>(&self.desc.out.id, output);
2439 }
2440 }
2441
2442 let mut streams = OperationStreams::default();
2443 streams.tensor(&tensor);
2444 let out = tensor
2445 .client
2446 .tensor_uninitialized(tensor.shape.clone(), dtype.into());
2447
2448 let desc = UnaryOpIr {
2449 input: tensor.into_ir(),
2450 out: out.to_ir_out(),
2451 };
2452 out.client.register(
2453 streams,
2454 OperationIr::BaseFloat(BaseOperationIr::Cast(desc.clone())),
2455 CastOps::<B>::new(desc, dtype),
2456 );
2457
2458 out
2459 }
2460
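// Unfold extracts sliding windows of length `size` taken every `step` elements along
// `dim`; the output shape is computed up front by `calculate_unfold_shape`.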
2461 fn float_unfold(
2462 tensor: FloatTensor<Self>,
2463 dim: usize,
2464 size: usize,
2465 step: usize,
2466 ) -> FloatTensor<Self> {
2467 #[derive(new, Debug)]
2468 struct UnfoldOps<B: FusionBackend> {
2469 desc: UnfoldOpIr,
2470 _b: PhantomData<B>,
2471 }
2472
2473 impl<B: FusionBackend> Operation<B::FusionRuntime> for UnfoldOps<B> {
2474 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2475 let input = handles.get_float_tensor::<B>(&self.desc.input);
2476 let output = B::float_unfold(input, self.desc.dim, self.desc.size, self.desc.step);
2477
2478 handles.register_float_tensor::<B>(&self.desc.out.id, output);
2479 }
2480 }
2481
2482 let mut streams = OperationStreams::default();
2483 streams.tensor(&tensor);
2484
2485 let shape = calculate_unfold_shape(tensor.shape(), dim, size, step);
2486
2487 let out = tensor
2488 .client
2489 .tensor_uninitialized(Shape::from(shape), tensor.dtype);
2490
2491 let desc = UnfoldOpIr {
2492 input: tensor.into_ir(),
2493 out: out.to_ir_out(),
2494 dim,
2495 size,
2496 step,
2497 };
2498
2499 out.client.register(
2500 streams,
2501 OperationIr::BaseFloat(BaseOperationIr::Unfold(desc.clone())),
2502 UnfoldOps::<B>::new(desc),
2503 );
2504
2505 out
2506 }
2507
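// `is_nan` and `is_inf` are element-wise predicates: the output has the same shape as
// the input but uses the backend's bool element dtype.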
2508 fn float_is_nan(tensor: FloatTensor<Self>) -> BoolTensor<Self> {
2509 #[derive(new, Debug)]
2510 struct IsNanOps<B: FusionBackend> {
2511 desc: UnaryOpIr,
2512 _b: PhantomData<B>,
2513 }
2514 impl<B: FusionBackend> Operation<B::FusionRuntime> for IsNanOps<B> {
2515 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2516 let input = handles.get_float_tensor::<B>(&self.desc.input);
2517 let output = B::float_is_nan(input);
2518 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
2519 }
2520 }
2521
2522 let mut streams = OperationStreams::default();
2523 streams.tensor(&tensor);
2524 let out = tensor
2525 .client
2526 .tensor_uninitialized(tensor.shape.clone(), B::BoolElem::dtype());
2527
2528 let dtype = tensor.dtype;
2529 let desc = UnaryOpIr {
2530 input: tensor.into_ir(),
2531 out: out.to_ir_out(),
2532 };
2533 out.client.register(
2534 streams,
2535 OperationIr::Float(dtype, FloatOperationIr::IsNan(desc.clone())),
2536 IsNanOps::<B>::new(desc),
2537 );
2538
2539 out
2540 }
2541
2542 fn float_is_inf(tensor: FloatTensor<Self>) -> BoolTensor<Self> {
2543 #[derive(new, Debug)]
2544 struct IsInfOps<B: FusionBackend> {
2545 desc: UnaryOpIr,
2546 _b: PhantomData<B>,
2547 }
2548 impl<B: FusionBackend> Operation<B::FusionRuntime> for IsInfOps<B> {
2549 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2550 let input = handles.get_float_tensor::<B>(&self.desc.input);
2551 let output = B::float_is_inf(input);
2552 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
2553 }
2554 }
2555
2556 let mut streams = OperationStreams::default();
2557 streams.tensor(&tensor);
2558 let out = tensor
2559 .client
2560 .tensor_uninitialized(tensor.shape.clone(), B::BoolElem::dtype());
2561
2562 let dtype = tensor.dtype;
2563 let desc = UnaryOpIr {
2564 input: tensor.into_ir(),
2565 out: out.to_ir_out(),
2566 };
2567 out.client.register(
2568 streams,
2569 OperationIr::Float(dtype, FloatOperationIr::IsInf(desc.clone())),
2570 IsInfOps::<B>::new(desc),
2571 );
2572
2573 out
2574 }
2575}