1use alloc::vec::Vec;
2use burn_backend::backend::ExecutionError;
3use burn_backend::{Scalar, tensor::FloatElem};
4use burn_std::{BoolDType, IntDType};
5
6use crate::{BackendRouter, RunnerChannel, RunnerClient, get_client};
7use burn_backend::tensor::{BoolTensor, Device, FloatTensor, IndexingUpdateOp, IntTensor};
8use burn_backend::{Distribution, FloatDType, Shape, Slice, TensorData, ops::FloatTensorOps};
9use burn_ir::{
10 BaseOperationIr, BinaryOpIr, CastOpIr, CatOpIr, ClampOpIr, CreationOpIr, CrossOpIr, DimOpIr,
11 FlipOpIr, FloatOperationIr, FullOpIr, GatherNdOpIr, GatherOpIr, InitOperationIr, MaskFillOpIr,
12 MaskWhereOpIr, MatmulOpIr, NumericOperationIr, OperationIr, OperationOutput, PermuteOpIr,
13 RandomOpIr, ReduceDimOpIr, ReduceDimWithIndicesOpIr, ReduceOpIr, RepeatDimOpIr, ScalarOpIr,
14 ScatterNdOpIr, ScatterOpIr, SelectAssignOpIr, SelectOpIr, ShapeOpIr, SliceAssignOpIr,
15 SliceOpIr, SwapDimsOpIr, UnaryOpIr, UnfoldOpIr,
16};
17
18impl<R: RunnerChannel> FloatTensorOps<Self> for BackendRouter<R> {
19 fn float_from_data(data: TensorData, device: &Device<Self>) -> FloatTensor<Self> {
20 let client = get_client::<R>(device);
21 let out = client.register_tensor_data(data);
22 let desc = InitOperationIr {
23 out: out.to_ir_out(),
24 };
25
26 client.register_op(OperationIr::Init(desc));
28
29 out
30 }
31
32 fn float_random(
33 shape: Shape,
34 distribution: Distribution,
35 device: &Device<Self>,
36 dtype: FloatDType,
37 ) -> FloatTensor<Self> {
38 let client = get_client::<R>(device);
39 let dtype = dtype.into();
40 let desc = RandomOpIr::create(shape, dtype, distribution, || client.create_empty_handle());
41
42 client
43 .register(OperationIr::Float(dtype, FloatOperationIr::Random(desc)))
44 .output()
45 }
46
47 fn float_zeros(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
48 let client = get_client::<R>(device);
49 let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
50
51 client
52 .register(OperationIr::BaseFloat(BaseOperationIr::Zeros(desc)))
53 .output()
54 }
55
56 fn float_ones(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
57 let client = get_client::<R>(device);
58 let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
59
60 client
61 .register(OperationIr::BaseFloat(BaseOperationIr::Ones(desc)))
62 .output()
63 }
64
65 fn float_full(
66 shape: Shape,
67 fill_value: Scalar,
68 device: &Device<Self>,
69 dtype: FloatDType,
70 ) -> FloatTensor<Self> {
71 let client = get_client::<R>(device);
72 let dtype = dtype.into();
73 let value = fill_value.into();
74 let desc = FullOpIr::create(shape, dtype, value, || client.create_empty_handle());
75
76 client
77 .register(OperationIr::NumericFloat(
78 desc.out.dtype,
79 NumericOperationIr::Full(desc),
80 ))
81 .output()
82 }
83
84 async fn float_into_data(tensor: FloatTensor<Self>) -> Result<TensorData, ExecutionError> {
85 Ok(tensor
86 .into_data()
87 .await?
88 .convert::<FloatElem<Self>>())
90 }
91
92 fn float_device(tensor: &FloatTensor<Self>) -> Device<Self> {
93 tensor.client.device()
94 }
95
96 fn float_to_device(tensor: FloatTensor<Self>, device: &Device<Self>) -> FloatTensor<Self> {
97 if &tensor.client.device() == device {
98 return tensor;
99 }
100 R::change_client_backend(tensor, device)
101 }
102
103 fn float_into_int(tensor: FloatTensor<Self>, out_dtype: IntDType) -> IntTensor<Self> {
104 let client = tensor.client.clone();
105 let desc = CastOpIr::create(tensor.into_ir(), out_dtype.into(), || {
106 client.create_empty_handle()
107 });
108
109 client
110 .register(OperationIr::Float(
111 desc.input.dtype,
112 FloatOperationIr::IntoInt(desc),
113 ))
114 .output()
115 }
116
117 fn float_empty(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
118 let client = get_client::<R>(device);
119 let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
120
121 client
122 .register(OperationIr::BaseFloat(BaseOperationIr::Empty(desc)))
123 .output()
124 }
125
126 fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
127 let client = lhs.client.clone();
128 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
129 client.create_empty_handle()
130 });
131
132 client
133 .register(OperationIr::NumericFloat(
134 desc.out.dtype,
135 NumericOperationIr::Add(desc),
136 ))
137 .output()
138 }
139
140 fn float_add_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
141 let client = lhs.client.clone();
142 let rhs = rhs.into();
143 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
144
145 client
146 .register(OperationIr::NumericFloat(
147 desc.out.dtype,
148 NumericOperationIr::AddScalar(desc),
149 ))
150 .output()
151 }
152
153 fn float_clamp(tensor: FloatTensor<Self>, min: Scalar, max: Scalar) -> FloatTensor<Self> {
154 let client = tensor.client.clone();
155 let min = min.into();
156 let max = max.into();
157 let desc = ClampOpIr::create(tensor.into_ir(), min, max, || client.create_empty_handle());
158
159 client
160 .register(OperationIr::NumericFloat(
161 desc.out.dtype,
162 NumericOperationIr::Clamp(desc),
163 ))
164 .output()
165 }
166
167 fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
168 let client = lhs.client.clone();
169 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
170 client.create_empty_handle()
171 });
172
173 client
174 .register(OperationIr::NumericFloat(
175 desc.out.dtype,
176 NumericOperationIr::Sub(desc),
177 ))
178 .output()
179 }
180
181 fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
182 let client = lhs.client.clone();
183 let rhs = rhs.into();
184 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
185
186 client
187 .register(OperationIr::NumericFloat(
188 desc.out.dtype,
189 NumericOperationIr::SubScalar(desc),
190 ))
191 .output()
192 }
193
194 fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
195 let client = lhs.client.clone();
196 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
197 client.create_empty_handle()
198 });
199
200 client
201 .register(OperationIr::NumericFloat(
202 desc.out.dtype,
203 NumericOperationIr::Mul(desc),
204 ))
205 .output()
206 }
207
208 fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
209 let client = lhs.client.clone();
210 let rhs = rhs.into();
211 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
212
213 client
214 .register(OperationIr::NumericFloat(
215 desc.out.dtype,
216 NumericOperationIr::MulScalar(desc),
217 ))
218 .output()
219 }
220
221 fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
222 let client = lhs.client.clone();
223 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
224 client.create_empty_handle()
225 });
226
227 client
228 .register(OperationIr::NumericFloat(
229 desc.out.dtype,
230 NumericOperationIr::Div(desc),
231 ))
232 .output()
233 }
234
235 fn float_div_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
236 let client = lhs.client.clone();
237 let rhs = rhs.into();
238 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
239
240 client
241 .register(OperationIr::NumericFloat(
242 desc.out.dtype,
243 NumericOperationIr::DivScalar(desc),
244 ))
245 .output()
246 }
247
248 fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
249 let client = lhs.client.clone();
250 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
251 client.create_empty_handle()
252 });
253
254 client
255 .register(OperationIr::NumericFloat(
256 desc.out.dtype,
257 NumericOperationIr::Rem(desc),
258 ))
259 .output()
260 }
261
262 fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
263 let client = lhs.client.clone();
264 let rhs = rhs.into();
265 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
266
267 client
268 .register(OperationIr::NumericFloat(
269 desc.out.dtype,
270 NumericOperationIr::RemScalar(desc),
271 ))
272 .output()
273 }
274
275 fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
276 let client = lhs.client.clone();
277 let desc = MatmulOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
278 client.create_empty_handle()
279 });
280
281 client
282 .register(OperationIr::Float(
283 desc.out.dtype,
284 FloatOperationIr::Matmul(desc),
285 ))
286 .output()
287 }
288
289 fn float_cross(
290 lhs: FloatTensor<Self>,
291 rhs: FloatTensor<Self>,
292 dim: usize,
293 ) -> FloatTensor<Self> {
294 let client = lhs.client.clone();
295 let desc = CrossOpIr::create(lhs.into_ir(), rhs.into_ir(), dim, || {
296 client.create_empty_handle()
297 });
298
299 client
300 .register(OperationIr::Float(
301 desc.out.dtype,
302 FloatOperationIr::Cross(desc),
303 ))
304 .output()
305 }
306
307 fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
308 let client = tensor.client.clone();
309 let desc = SwapDimsOpIr::create(tensor.into_ir(), dim1, dim2, || {
310 client.create_empty_handle()
311 });
312
313 client
314 .register(OperationIr::BaseFloat(BaseOperationIr::SwapDims(desc)))
315 .output()
316 }
317
318 fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
319 let client = tensor.client.clone();
320 let desc = ShapeOpIr::reshape(tensor.into_ir(), shape, || client.create_empty_handle());
321
322 client
323 .register(OperationIr::BaseFloat(BaseOperationIr::Reshape(desc)))
324 .output()
325 }
326
327 fn float_gather(
328 dim: usize,
329 tensor: FloatTensor<Self>,
330 indices: IntTensor<Self>,
331 ) -> FloatTensor<Self> {
332 let client = tensor.client.clone();
333 let desc = GatherOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
334 client.create_empty_handle()
335 });
336
337 client
338 .register(OperationIr::BaseFloat(BaseOperationIr::Gather(desc)))
339 .output()
340 }
341
342 fn float_scatter_add(
343 dim: usize,
344 tensor: FloatTensor<Self>,
345 indices: IntTensor<Self>,
346 value: FloatTensor<Self>,
347 ) -> FloatTensor<Self> {
348 let client = tensor.client.clone();
349 let desc = ScatterOpIr::create(
350 tensor.into_ir(),
351 dim,
352 indices.into_ir(),
353 value.into_ir(),
354 IndexingUpdateOp::Add,
355 || client.create_empty_handle(),
356 );
357
358 client
359 .register(OperationIr::BaseFloat(BaseOperationIr::Scatter(desc)))
360 .output()
361 }
362
363 fn float_scatter_nd(
364 data: FloatTensor<Self>,
365 indices: IntTensor<Self>,
366 values: FloatTensor<Self>,
367 reduction: IndexingUpdateOp,
368 ) -> FloatTensor<Self> {
369 let client = data.client.clone();
370 let desc = ScatterNdOpIr::create(
371 data.into_ir(),
372 indices.into_ir(),
373 values.into_ir(),
374 reduction,
375 || client.create_empty_handle(),
376 );
377
378 client
379 .register(OperationIr::BaseFloat(BaseOperationIr::ScatterNd(desc)))
380 .output()
381 }
382
383 fn float_gather_nd(data: FloatTensor<Self>, indices: IntTensor<Self>) -> FloatTensor<Self> {
384 let client = data.client.clone();
385 let desc = GatherNdOpIr::create(data.into_ir(), indices.into_ir(), || {
386 client.create_empty_handle()
387 });
388
389 client
390 .register(OperationIr::BaseFloat(BaseOperationIr::GatherNd(desc)))
391 .output()
392 }
393
394 fn float_select(
395 tensor: FloatTensor<Self>,
396 dim: usize,
397 indices: IntTensor<Self>,
398 ) -> FloatTensor<Self> {
399 let client = tensor.client.clone();
400 let desc = SelectOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
401 client.create_empty_handle()
402 });
403
404 client
405 .register(OperationIr::BaseFloat(BaseOperationIr::Select(desc)))
406 .output()
407 }
408
409 fn float_select_add(
410 tensor: FloatTensor<Self>,
411 dim: usize,
412 indices: IntTensor<Self>,
413 value: FloatTensor<Self>,
414 ) -> FloatTensor<Self> {
415 let client = tensor.client.clone();
416 let desc = SelectAssignOpIr::create(
417 tensor.into_ir(),
418 dim,
419 indices.into_ir(),
420 value.into_ir(),
421 IndexingUpdateOp::Add,
422 || client.create_empty_handle(),
423 );
424
425 client
426 .register(OperationIr::BaseFloat(BaseOperationIr::SelectAssign(desc)))
427 .output()
428 }
429
430 fn float_slice(tensor: FloatTensor<Self>, slices: &[Slice]) -> FloatTensor<Self> {
431 let client = tensor.client.clone();
432 let desc = SliceOpIr::create(tensor.into_ir(), slices.into(), || {
433 client.create_empty_handle()
434 });
435
436 client
437 .register(OperationIr::BaseFloat(BaseOperationIr::Slice(desc)))
438 .output()
439 }
440
441 fn float_slice_assign(
442 tensor: FloatTensor<Self>,
443 slices: &[burn_backend::Slice],
444 value: FloatTensor<Self>,
445 ) -> FloatTensor<Self> {
446 let client = tensor.client.clone();
447 let desc =
448 SliceAssignOpIr::create(tensor.into_ir(), slices.into(), value.into_ir(), || {
449 client.create_empty_handle()
450 });
451
452 client
453 .register(OperationIr::BaseFloat(BaseOperationIr::SliceAssign(desc)))
454 .output()
455 }
456
457 fn float_mask_where(
458 tensor: FloatTensor<Self>,
459 mask: BoolTensor<Self>,
460 value: FloatTensor<Self>,
461 ) -> FloatTensor<Self> {
462 let client = tensor.client.clone();
463 let desc = MaskWhereOpIr::create(tensor.into_ir(), mask.into_ir(), value.into_ir(), || {
464 client.create_empty_handle()
465 });
466
467 client
468 .register(OperationIr::BaseFloat(BaseOperationIr::MaskWhere(desc)))
469 .output()
470 }
471
472 fn float_mask_fill(
473 tensor: FloatTensor<Self>,
474 mask: BoolTensor<Self>,
475 value: Scalar,
476 ) -> FloatTensor<Self> {
477 let client = tensor.client.clone();
478 let value = value.into();
479 let desc = MaskFillOpIr::create(tensor.into_ir(), mask.into_ir(), value, || {
480 client.create_empty_handle()
481 });
482
483 client
484 .register(OperationIr::BaseFloat(BaseOperationIr::MaskFill(desc)))
485 .output()
486 }
487
488 fn float_equal(
489 lhs: FloatTensor<Self>,
490 rhs: FloatTensor<Self>,
491 out_dtype: BoolDType,
492 ) -> BoolTensor<Self> {
493 let client = lhs.client.clone();
494 let desc =
495 BinaryOpIr::create_comparison(lhs.into_ir(), rhs.into_ir(), out_dtype.into(), || {
496 client.create_empty_handle()
497 });
498
499 client
500 .register(OperationIr::BaseFloat(BaseOperationIr::Equal(desc)))
501 .output()
502 }
503
504 fn float_equal_elem(
505 lhs: FloatTensor<Self>,
506 rhs: Scalar,
507 out_dtype: BoolDType,
508 ) -> BoolTensor<Self> {
509 let client = lhs.client.clone();
510 let rhs = rhs.into();
511 let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, out_dtype.into(), || {
512 client.create_empty_handle()
513 });
514
515 client
516 .register(OperationIr::BaseFloat(BaseOperationIr::EqualElem(desc)))
517 .output()
518 }
519
520 fn float_greater(
521 lhs: FloatTensor<Self>,
522 rhs: FloatTensor<Self>,
523 out_dtype: BoolDType,
524 ) -> BoolTensor<Self> {
525 let client = lhs.client.clone();
526 let desc =
527 BinaryOpIr::create_comparison(lhs.into_ir(), rhs.into_ir(), out_dtype.into(), || {
528 client.create_empty_handle()
529 });
530
531 client
532 .register(OperationIr::NumericFloat(
533 desc.lhs.dtype,
534 NumericOperationIr::Greater(desc),
535 ))
536 .output()
537 }
538
539 fn float_greater_elem(
540 lhs: FloatTensor<Self>,
541 rhs: Scalar,
542 out_dtype: BoolDType,
543 ) -> BoolTensor<Self> {
544 let client = lhs.client.clone();
545 let rhs = rhs.into();
546 let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, out_dtype.into(), || {
547 client.create_empty_handle()
548 });
549
550 client
551 .register(OperationIr::NumericFloat(
552 desc.lhs.dtype,
553 NumericOperationIr::GreaterElem(desc),
554 ))
555 .output()
556 }
557
558 fn float_greater_equal(
559 lhs: FloatTensor<Self>,
560 rhs: FloatTensor<Self>,
561 out_dtype: BoolDType,
562 ) -> BoolTensor<Self> {
563 let client = lhs.client.clone();
564 let desc =
565 BinaryOpIr::create_comparison(lhs.into_ir(), rhs.into_ir(), out_dtype.into(), || {
566 client.create_empty_handle()
567 });
568
569 client
570 .register(OperationIr::NumericFloat(
571 desc.lhs.dtype,
572 NumericOperationIr::GreaterEqual(desc),
573 ))
574 .output()
575 }
576
577 fn float_greater_equal_elem(
578 lhs: FloatTensor<Self>,
579 rhs: Scalar,
580 out_dtype: BoolDType,
581 ) -> BoolTensor<Self> {
582 let client = lhs.client.clone();
583 let rhs = rhs.into();
584 let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, out_dtype.into(), || {
585 client.create_empty_handle()
586 });
587
588 client
589 .register(OperationIr::NumericFloat(
590 desc.lhs.dtype,
591 NumericOperationIr::GreaterEqualElem(desc),
592 ))
593 .output()
594 }
595
596 fn float_lower(
597 lhs: FloatTensor<Self>,
598 rhs: FloatTensor<Self>,
599 out_dtype: BoolDType,
600 ) -> BoolTensor<Self> {
601 let client = lhs.client.clone();
602 let desc =
603 BinaryOpIr::create_comparison(lhs.into_ir(), rhs.into_ir(), out_dtype.into(), || {
604 client.create_empty_handle()
605 });
606
607 client
608 .register(OperationIr::NumericFloat(
609 desc.lhs.dtype,
610 NumericOperationIr::Lower(desc),
611 ))
612 .output()
613 }
614
615 fn float_lower_elem(
616 lhs: FloatTensor<Self>,
617 rhs: Scalar,
618 out_dtype: BoolDType,
619 ) -> BoolTensor<Self> {
620 let client = lhs.client.clone();
621 let rhs = rhs.into();
622 let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, out_dtype.into(), || {
623 client.create_empty_handle()
624 });
625
626 client
627 .register(OperationIr::NumericFloat(
628 desc.lhs.dtype,
629 NumericOperationIr::LowerElem(desc),
630 ))
631 .output()
632 }
633
634 fn float_lower_equal(
635 lhs: FloatTensor<Self>,
636 rhs: FloatTensor<Self>,
637 out_dtype: BoolDType,
638 ) -> BoolTensor<Self> {
639 let client = lhs.client.clone();
640 let desc =
641 BinaryOpIr::create_comparison(lhs.into_ir(), rhs.into_ir(), out_dtype.into(), || {
642 client.create_empty_handle()
643 });
644
645 client
646 .register(OperationIr::NumericFloat(
647 desc.lhs.dtype,
648 NumericOperationIr::LowerEqual(desc),
649 ))
650 .output()
651 }
652
653 fn float_lower_equal_elem(
654 lhs: FloatTensor<Self>,
655 rhs: Scalar,
656 out_dtype: BoolDType,
657 ) -> BoolTensor<Self> {
658 let client = lhs.client.clone();
659 let rhs = rhs.into();
660 let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, out_dtype.into(), || {
661 client.create_empty_handle()
662 });
663
664 client
665 .register(OperationIr::NumericFloat(
666 desc.lhs.dtype,
667 NumericOperationIr::LowerEqualElem(desc),
668 ))
669 .output()
670 }
671
672 fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
673 let client = tensor.client.clone();
674 let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
675
676 client
677 .register(OperationIr::NumericFloat(
678 desc.out.dtype,
679 NumericOperationIr::Sum(desc),
680 ))
681 .output()
682 }
683
684 fn float_sum_dim(tensor: FloatTensor<Self>, axis: usize) -> FloatTensor<Self> {
685 let client = tensor.client.clone();
686 let desc =
687 ReduceDimOpIr::create(tensor.into_ir(), axis, 1, || client.create_empty_handle());
688
689 client
690 .register(OperationIr::NumericFloat(
691 desc.out.dtype,
692 NumericOperationIr::SumDim(desc),
693 ))
694 .output()
695 }
696
697 fn float_prod(tensor: IntTensor<Self>) -> IntTensor<Self> {
698 let client = tensor.client.clone();
699 let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
700
701 client
702 .register(OperationIr::NumericFloat(
703 desc.out.dtype,
704 NumericOperationIr::Prod(desc),
705 ))
706 .output()
707 }
708
709 fn float_prod_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
710 let client = tensor.client.clone();
711 let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
712
713 client
714 .register(OperationIr::NumericFloat(
715 desc.out.dtype,
716 NumericOperationIr::ProdDim(desc),
717 ))
718 .output()
719 }
720
721 fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
722 let client = tensor.client.clone();
723 let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
724
725 client
726 .register(OperationIr::NumericFloat(
727 desc.out.dtype,
728 NumericOperationIr::Mean(desc),
729 ))
730 .output()
731 }
732
733 fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
734 let client = tensor.client.clone();
735 let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
736
737 client
738 .register(OperationIr::NumericFloat(
739 desc.out.dtype,
740 NumericOperationIr::MeanDim(desc),
741 ))
742 .output()
743 }
744
745 fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
746 let client = tensor.client.clone();
747 let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
748
749 client
750 .register(OperationIr::NumericFloat(
751 desc.out.dtype,
752 NumericOperationIr::CumSum(desc),
753 ))
754 .output()
755 }
756
757 fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
758 let client = tensor.client.clone();
759 let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
760
761 client
762 .register(OperationIr::NumericFloat(
763 desc.out.dtype,
764 NumericOperationIr::CumProd(desc),
765 ))
766 .output()
767 }
768
769 fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
770 let client = tensor.client.clone();
771 let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
772
773 client
774 .register(OperationIr::NumericFloat(
775 desc.out.dtype,
776 NumericOperationIr::CumMin(desc),
777 ))
778 .output()
779 }
780
781 fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
782 let client = tensor.client.clone();
783 let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
784
785 client
786 .register(OperationIr::NumericFloat(
787 desc.out.dtype,
788 NumericOperationIr::CumMax(desc),
789 ))
790 .output()
791 }
792
793 fn float_exp(lhs: FloatTensor<Self>) -> FloatTensor<Self> {
794 let client = lhs.client.clone();
795 let desc = UnaryOpIr::create(lhs.into_ir(), || client.create_empty_handle());
796
797 client
798 .register(OperationIr::Float(
799 desc.out.dtype,
800 FloatOperationIr::Exp(desc),
801 ))
802 .output()
803 }
804
805 fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
806 let client = tensor.client.clone();
807 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
808
809 client
810 .register(OperationIr::Float(
811 desc.out.dtype,
812 FloatOperationIr::Log(desc),
813 ))
814 .output()
815 }
816
817 fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
818 let client = tensor.client.clone();
819 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
820
821 client
822 .register(OperationIr::Float(
823 desc.out.dtype,
824 FloatOperationIr::Log1p(desc),
825 ))
826 .output()
827 }
828
829 fn float_powf_scalar_impl(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
830 let client = lhs.client.clone();
831 let rhs = rhs.into();
832 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
833
834 client
835 .register(OperationIr::Float(
836 desc.out.dtype,
837 FloatOperationIr::PowfScalar(desc),
838 ))
839 .output()
840 }
841
842 fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
843 let client = tensor.client.clone();
844 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
845
846 client
847 .register(OperationIr::Float(
848 desc.out.dtype,
849 FloatOperationIr::Sqrt(desc),
850 ))
851 .output()
852 }
853
854 fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
855 let client = tensor.client.clone();
856 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
857
858 client
859 .register(OperationIr::NumericFloat(
860 desc.out.dtype,
861 NumericOperationIr::Abs(desc),
862 ))
863 .output()
864 }
865
866 fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
867 let client = tensor.client.clone();
868 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
869
870 client
871 .register(OperationIr::Float(
872 desc.out.dtype,
873 FloatOperationIr::Cos(desc),
874 ))
875 .output()
876 }
877
878 fn float_cosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
879 let client = tensor.client.clone();
880 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
881
882 client
883 .register(OperationIr::Float(
884 desc.out.dtype,
885 FloatOperationIr::Cosh(desc),
886 ))
887 .output()
888 }
889
890 fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
891 let client = tensor.client.clone();
892 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
893
894 client
895 .register(OperationIr::Float(
896 desc.out.dtype,
897 FloatOperationIr::Sin(desc),
898 ))
899 .output()
900 }
901
902 fn float_sinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
903 let client = tensor.client.clone();
904 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
905
906 client
907 .register(OperationIr::Float(
908 desc.out.dtype,
909 FloatOperationIr::Sinh(desc),
910 ))
911 .output()
912 }
913
914 fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
915 let client = tensor.client.clone();
916 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
917
918 client
919 .register(OperationIr::Float(
920 desc.out.dtype,
921 FloatOperationIr::Tan(desc),
922 ))
923 .output()
924 }
925
926 fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
927 let client = tensor.client.clone();
928 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
929
930 client
931 .register(OperationIr::Float(
932 desc.out.dtype,
933 FloatOperationIr::Tanh(desc),
934 ))
935 .output()
936 }
937
938 fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
939 let client = tensor.client.clone();
940 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
941
942 client
943 .register(OperationIr::Float(
944 desc.out.dtype,
945 FloatOperationIr::ArcCos(desc),
946 ))
947 .output()
948 }
949
950 fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
951 let client = tensor.client.clone();
952 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
953
954 client
955 .register(OperationIr::Float(
956 desc.out.dtype,
957 FloatOperationIr::ArcCosh(desc),
958 ))
959 .output()
960 }
961
962 fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
963 let client = tensor.client.clone();
964 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
965
966 client
967 .register(OperationIr::Float(
968 desc.out.dtype,
969 FloatOperationIr::ArcSin(desc),
970 ))
971 .output()
972 }
973
974 fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
975 let client = tensor.client.clone();
976 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
977
978 client
979 .register(OperationIr::Float(
980 desc.out.dtype,
981 FloatOperationIr::ArcSinh(desc),
982 ))
983 .output()
984 }
985
986 fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
987 let client = tensor.client.clone();
988 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
989
990 client
991 .register(OperationIr::Float(
992 desc.out.dtype,
993 FloatOperationIr::ArcTan(desc),
994 ))
995 .output()
996 }
997
998 fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
999 let client = tensor.client.clone();
1000 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1001
1002 client
1003 .register(OperationIr::Float(
1004 desc.out.dtype,
1005 FloatOperationIr::ArcTanh(desc),
1006 ))
1007 .output()
1008 }
1009
1010 fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
1011 let client = lhs.client.clone();
1012 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
1013 client.create_empty_handle()
1014 });
1015
1016 client
1017 .register(OperationIr::Float(
1018 desc.out.dtype,
1019 FloatOperationIr::ArcTan2(desc),
1020 ))
1021 .output()
1022 }
1023
1024 fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1025 let client = tensor.client.clone();
1026 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1027
1028 client
1029 .register(OperationIr::Float(
1030 desc.out.dtype,
1031 FloatOperationIr::Round(desc),
1032 ))
1033 .output()
1034 }
1035
1036 fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1037 let client = tensor.client.clone();
1038 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1039
1040 client
1041 .register(OperationIr::Float(
1042 desc.out.dtype,
1043 FloatOperationIr::Floor(desc),
1044 ))
1045 .output()
1046 }
1047
1048 fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1049 let client = tensor.client.clone();
1050 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1051
1052 client
1053 .register(OperationIr::Float(
1054 desc.out.dtype,
1055 FloatOperationIr::Ceil(desc),
1056 ))
1057 .output()
1058 }
1059
1060 fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1061 let client = tensor.client.clone();
1062 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1063
1064 client
1065 .register(OperationIr::Float(
1066 desc.out.dtype,
1067 FloatOperationIr::Trunc(desc),
1068 ))
1069 .output()
1070 }
1071
1072 fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1073 let client = tensor.client.clone();
1074 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1075
1076 client
1077 .register(OperationIr::Float(
1078 desc.out.dtype,
1079 FloatOperationIr::Recip(desc),
1080 ))
1081 .output()
1082 }
1083
1084 fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1085 let client = tensor.client.clone();
1086 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1087
1088 client
1089 .register(OperationIr::Float(
1090 desc.out.dtype,
1091 FloatOperationIr::Erf(desc),
1092 ))
1093 .output()
1094 }
1095
1096 fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {
1097 let client = tensors.first().unwrap().client.clone();
1098 let tensors = tensors.into_iter().map(|t| t.into_ir()).collect();
1099 let desc = CatOpIr::create(tensors, dim, || client.create_empty_handle());
1100
1101 client
1102 .register(OperationIr::BaseFloat(BaseOperationIr::Cat(desc)))
1103 .output()
1104 }
1105
1106 fn float_argmax(tensor: FloatTensor<Self>, dim: usize, out_dtype: IntDType) -> IntTensor<Self> {
1107 let client = tensor.client.clone();
1108 let desc = ReduceDimOpIr::create_arg(tensor.into_ir(), dim, 1, out_dtype.into(), || {
1109 client.create_empty_handle()
1110 });
1111
1112 client
1113 .register(OperationIr::NumericFloat(
1114 desc.input.dtype,
1115 NumericOperationIr::ArgMax(desc),
1116 ))
1117 .output()
1118 }
1119
1120 fn float_argtopk(
1121 tensor: FloatTensor<Self>,
1122 dim: usize,
1123 k: usize,
1124 out_dtype: IntDType,
1125 ) -> IntTensor<Self> {
1126 let client = tensor.client.clone();
1127 let desc = ReduceDimOpIr::create_arg(tensor.into_ir(), dim, k, out_dtype.into(), || {
1128 client.create_empty_handle()
1129 });
1130
1131 client
1132 .register(OperationIr::NumericFloat(
1133 desc.input.dtype,
1134 NumericOperationIr::ArgTopK(desc),
1135 ))
1136 .output()
1137 }
1138
1139 fn float_topk(tensor: FloatTensor<Self>, dim: usize, k: usize) -> FloatTensor<Self> {
1140 let client = tensor.client.clone();
1141 let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, k, || client.create_empty_handle());
1142
1143 client
1144 .register(OperationIr::NumericFloat(
1145 desc.input.dtype,
1146 NumericOperationIr::TopK(desc),
1147 ))
1148 .output()
1149 }
1150
1151 fn float_repeat_dim(tensor: FloatTensor<Self>, dim: usize, times: usize) -> FloatTensor<Self> {
1152 let client = tensor.client.clone();
1153 let desc = RepeatDimOpIr::create(tensor.into_ir(), dim, times, || {
1154 client.create_empty_handle()
1155 });
1156
1157 client
1158 .register(OperationIr::BaseFloat(BaseOperationIr::RepeatDim(desc)))
1159 .output()
1160 }
1161
1162 fn float_argmin(tensor: FloatTensor<Self>, dim: usize, out_dtype: IntDType) -> IntTensor<Self> {
1163 let client = tensor.client.clone();
1164 let desc = ReduceDimOpIr::create_arg(tensor.into_ir(), dim, 1, out_dtype.into(), || {
1165 client.create_empty_handle()
1166 });
1167
1168 client
1169 .register(OperationIr::NumericFloat(
1170 desc.input.dtype,
1171 NumericOperationIr::ArgMin(desc),
1172 ))
1173 .output()
1174 }
1175
1176 fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1177 let client = tensor.client.clone();
1178 let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1179
1180 client
1181 .register(OperationIr::NumericFloat(
1182 desc.out.dtype,
1183 NumericOperationIr::Max(desc),
1184 ))
1185 .output()
1186 }
1187
1188 fn float_max_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
1189 let client = tensor.client.clone();
1190 let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
1191
1192 client
1193 .register(OperationIr::NumericFloat(
1194 desc.out.dtype,
1195 NumericOperationIr::MaxDim(desc),
1196 ))
1197 .output()
1198 }
1199
1200 fn float_max_dim_with_indices(
1201 tensor: FloatTensor<Self>,
1202 dim: usize,
1203 indices_dtype: IntDType,
1204 ) -> (FloatTensor<Self>, IntTensor<Self>) {
1205 let client = tensor.client.clone();
1206 let desc =
1207 ReduceDimWithIndicesOpIr::create(tensor.into_ir(), dim, indices_dtype.into(), || {
1208 client.create_empty_handle()
1209 });
1210
1211 client
1212 .register(OperationIr::NumericFloat(
1213 desc.tensor.dtype,
1214 NumericOperationIr::MaxDimWithIndices(desc),
1215 ))
1216 .outputs()
1217 .into()
1218 }
1219
1220 fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1221 let client = tensor.client.clone();
1222 let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1223
1224 client
1225 .register(OperationIr::NumericFloat(
1226 desc.out.dtype,
1227 NumericOperationIr::Min(desc),
1228 ))
1229 .output()
1230 }
1231
1232 fn float_min_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
1233 let client = tensor.client.clone();
1234 let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
1235
1236 client
1237 .register(OperationIr::NumericFloat(
1238 desc.out.dtype,
1239 NumericOperationIr::MinDim(desc),
1240 ))
1241 .output()
1242 }
1243
1244 fn float_min_dim_with_indices(
1245 tensor: FloatTensor<Self>,
1246 dim: usize,
1247 indices_dtype: IntDType,
1248 ) -> (FloatTensor<Self>, IntTensor<Self>) {
1249 let client = tensor.client.clone();
1250 let desc =
1251 ReduceDimWithIndicesOpIr::create(tensor.into_ir(), dim, indices_dtype.into(), || {
1252 client.create_empty_handle()
1253 });
1254
1255 client
1256 .register(OperationIr::NumericFloat(
1257 desc.tensor.dtype,
1258 NumericOperationIr::MinDimWithIndices(desc),
1259 ))
1260 .outputs()
1261 .into()
1262 }
1263
1264 fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
1265 let client = lhs.client.clone();
1266 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
1267 client.create_empty_handle()
1268 });
1269
1270 client
1271 .register(OperationIr::Float(
1272 desc.out.dtype,
1273 FloatOperationIr::Powf(desc),
1274 ))
1275 .output()
1276 }
1277
1278 fn float_powi(lhs: FloatTensor<Self>, rhs: IntTensor<Self>) -> FloatTensor<Self> {
1279 let client = lhs.client.clone();
1280 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
1281 client.create_empty_handle()
1282 });
1283
1284 client
1285 .register(OperationIr::NumericFloat(
1286 desc.out.dtype,
1287 NumericOperationIr::Powi(desc),
1288 ))
1289 .output()
1290 }
1291
1292 fn float_powi_scalar_impl(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
1293 let client = lhs.client.clone();
1294 let rhs = rhs.into();
1295 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
1296
1297 client
1298 .register(OperationIr::NumericFloat(
1299 desc.out.dtype,
1300 NumericOperationIr::PowiScalar(desc),
1301 ))
1302 .output()
1303 }
1304
1305 fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
1306 let client = tensor.client.clone();
1307 let desc = PermuteOpIr::create(tensor.into_ir(), axes.into(), || {
1308 client.create_empty_handle()
1309 });
1310
1311 client
1312 .register(OperationIr::BaseFloat(BaseOperationIr::Permute(desc)))
1313 .output()
1314 }
1315
1316 fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
1317 let client = tensor.client.clone();
1318 let desc = ShapeOpIr::expand(tensor.into_ir(), shape, || client.create_empty_handle());
1319
1320 client
1321 .register(OperationIr::BaseFloat(BaseOperationIr::Expand(desc)))
1322 .output()
1323 }
1324
1325 fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
1326 let client = tensor.client.clone();
1327 let desc = FlipOpIr::create(tensor.into_ir(), axes.into(), || {
1328 client.create_empty_handle()
1329 });
1330
1331 client
1332 .register(OperationIr::BaseFloat(BaseOperationIr::Flip(desc)))
1333 .output()
1334 }
1335
1336 fn float_cast(tensor: FloatTensor<Self>, dtype: burn_backend::FloatDType) -> FloatTensor<Self> {
1337 let client = tensor.client.clone();
1338 let desc = CastOpIr::create(tensor.into_ir(), dtype.into(), || {
1339 client.create_empty_handle()
1340 });
1341
1342 client
1343 .register(OperationIr::BaseFloat(BaseOperationIr::Cast(desc)))
1344 .output()
1345 }
1346
1347 fn float_unfold(
1348 tensor: FloatTensor<Self>,
1349 dim: usize,
1350 size: usize,
1351 step: usize,
1352 ) -> FloatTensor<Self> {
1353 let client = tensor.client.clone();
1354 let desc = UnfoldOpIr::create(tensor.into_ir(), dim, size, step, || {
1355 client.create_empty_handle()
1356 });
1357
1358 client
1359 .register(OperationIr::BaseFloat(BaseOperationIr::Unfold(desc)))
1360 .output()
1361 }
1362}