1use alloc::vec::Vec;
2use burn_backend::backend::{Backend, ExecutionError};
3
4use crate::{BackendRouter, RunnerChannel, RunnerClient, get_client};
5use burn_backend::tensor::{
6 BoolTensor, Device, FloatElem, FloatTensor, IndexingUpdateOp, IntElem, IntTensor,
7};
8use burn_backend::{
9 Distribution, Element, FloatDType, Shape, Slice, TensorData, ops::FloatTensorOps,
10};
11use burn_ir::{
12 BaseOperationIr, BinaryOpIr, CastOpIr, CatOpIr, ClampOpIr, CreationOpIr, CrossOpIr, DimOpIr,
13 FlipOpIr, FloatOperationIr, FullOpIr, GatherOpIr, InitOperationIr, MaskFillOpIr, MaskWhereOpIr,
14 MatmulOpIr, NumericOperationIr, OperationIr, OperationOutput, PermuteOpIr, RandomOpIr,
15 ReduceDimOpIr, ReduceDimWithIndicesOpIr, ReduceOpIr, RepeatDimOpIr, ScalarIr, ScalarOpIr,
16 ScatterOpIr, SelectAssignOpIr, SelectOpIr, ShapeOpIr, SliceAssignOpIr, SliceOpIr, SwapDimsOpIr,
17 UnaryOpIr, UnfoldOpIr,
18};
19
20impl<R: RunnerChannel> FloatTensorOps<Self> for BackendRouter<R> {
21 fn float_from_data(data: TensorData, device: &Device<Self>) -> FloatTensor<Self> {
22 let client = get_client::<R>(device);
23 let out = client.register_tensor_data(data);
24 let desc = InitOperationIr {
25 out: out.to_ir_out(),
26 };
27
28 client.register_op(OperationIr::Init(desc));
30
31 out
32 }
33
34 fn float_random(
35 shape: Shape,
36 distribution: Distribution,
37 device: &Device<Self>,
38 ) -> FloatTensor<Self> {
39 let client = get_client::<R>(device);
40 let dtype = FloatElem::<Self>::dtype();
41 let desc = RandomOpIr::create(shape, dtype, distribution, || client.create_empty_handle());
42
43 client
44 .register(OperationIr::Float(dtype, FloatOperationIr::Random(desc)))
45 .output()
46 }
47
48 fn float_zeros(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
49 let client = get_client::<R>(device);
50 let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
51
52 client
53 .register(OperationIr::BaseFloat(BaseOperationIr::Zeros(desc)))
54 .output()
55 }
56
57 fn float_ones(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
58 let client = get_client::<R>(device);
59 let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
60
61 client
62 .register(OperationIr::BaseFloat(BaseOperationIr::Ones(desc)))
63 .output()
64 }
65
66 fn float_full(
67 shape: Shape,
68 fill_value: FloatElem<Self>,
69 device: &Device<Self>,
70 dtype: FloatDType,
71 ) -> FloatTensor<Self> {
72 let client = get_client::<R>(device);
73 let dtype = dtype.into();
74 let value = ScalarIr::with_dtype(fill_value, &dtype);
75 let desc = FullOpIr::create(shape, dtype, value, || client.create_empty_handle());
76
77 client
78 .register(OperationIr::NumericFloat(
79 desc.out.dtype,
80 NumericOperationIr::Full(desc),
81 ))
82 .output()
83 }
84
85 async fn float_into_data(tensor: FloatTensor<Self>) -> Result<TensorData, ExecutionError> {
86 Ok(tensor
87 .into_data()
88 .await?
89 .convert::<<Self as Backend>::FloatElem>())
91 }
92
93 fn float_device(tensor: &FloatTensor<Self>) -> Device<Self> {
94 tensor.client.device()
95 }
96
97 fn float_to_device(tensor: FloatTensor<Self>, device: &Device<Self>) -> FloatTensor<Self> {
98 if &tensor.client.device() == device {
99 return tensor;
100 }
101 R::change_client_backend(tensor, device)
102 }
103
104 fn float_into_int(tensor: FloatTensor<Self>) -> IntTensor<Self> {
105 let client = tensor.client.clone();
106 let desc = CastOpIr::create(tensor.into_ir(), IntElem::<Self>::dtype(), || {
107 client.create_empty_handle()
108 });
109
110 client
111 .register(OperationIr::Float(
112 desc.input.dtype,
113 FloatOperationIr::IntoInt(desc),
114 ))
115 .output()
116 }
117
118 fn float_empty(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {
119 let client = get_client::<R>(device);
120 let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
121
122 client
123 .register(OperationIr::BaseFloat(BaseOperationIr::Empty(desc)))
124 .output()
125 }
126
127 fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
128 let client = lhs.client.clone();
129 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
130 client.create_empty_handle()
131 });
132
133 client
134 .register(OperationIr::NumericFloat(
135 desc.out.dtype,
136 NumericOperationIr::Add(desc),
137 ))
138 .output()
139 }
140
141 fn float_add_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
142 let client = lhs.client.clone();
143 let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
144 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
145
146 client
147 .register(OperationIr::NumericFloat(
148 desc.out.dtype,
149 NumericOperationIr::AddScalar(desc),
150 ))
151 .output()
152 }
153
154 fn float_clamp(
155 tensor: FloatTensor<Self>,
156 min: FloatElem<Self>,
157 max: FloatElem<Self>,
158 ) -> FloatTensor<Self> {
159 let client = tensor.client.clone();
160 let min = ScalarIr::with_dtype(min, &tensor.dtype);
161 let max = ScalarIr::with_dtype(max, &tensor.dtype);
162 let desc = ClampOpIr::create(tensor.into_ir(), min, max, || client.create_empty_handle());
163
164 client
165 .register(OperationIr::NumericFloat(
166 desc.out.dtype,
167 NumericOperationIr::Clamp(desc),
168 ))
169 .output()
170 }
171
172 fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
173 let client = lhs.client.clone();
174 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
175 client.create_empty_handle()
176 });
177
178 client
179 .register(OperationIr::NumericFloat(
180 desc.out.dtype,
181 NumericOperationIr::Sub(desc),
182 ))
183 .output()
184 }
185
186 fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
187 let client = lhs.client.clone();
188 let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
189 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
190
191 client
192 .register(OperationIr::NumericFloat(
193 desc.out.dtype,
194 NumericOperationIr::SubScalar(desc),
195 ))
196 .output()
197 }
198
199 fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
200 let client = lhs.client.clone();
201 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
202 client.create_empty_handle()
203 });
204
205 client
206 .register(OperationIr::NumericFloat(
207 desc.out.dtype,
208 NumericOperationIr::Mul(desc),
209 ))
210 .output()
211 }
212
213 fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
214 let client = lhs.client.clone();
215 let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
216 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
217
218 client
219 .register(OperationIr::NumericFloat(
220 desc.out.dtype,
221 NumericOperationIr::MulScalar(desc),
222 ))
223 .output()
224 }
225
226 fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
227 let client = lhs.client.clone();
228 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
229 client.create_empty_handle()
230 });
231
232 client
233 .register(OperationIr::NumericFloat(
234 desc.out.dtype,
235 NumericOperationIr::Div(desc),
236 ))
237 .output()
238 }
239
240 fn float_div_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
241 let client = lhs.client.clone();
242 let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
243 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
244
245 client
246 .register(OperationIr::NumericFloat(
247 desc.out.dtype,
248 NumericOperationIr::DivScalar(desc),
249 ))
250 .output()
251 }
252
253 fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
254 let client = lhs.client.clone();
255 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
256 client.create_empty_handle()
257 });
258
259 client
260 .register(OperationIr::NumericFloat(
261 desc.out.dtype,
262 NumericOperationIr::Rem(desc),
263 ))
264 .output()
265 }
266
267 fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> FloatTensor<Self> {
268 let client = lhs.client.clone();
269 let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
270 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
271
272 client
273 .register(OperationIr::NumericFloat(
274 desc.out.dtype,
275 NumericOperationIr::RemScalar(desc),
276 ))
277 .output()
278 }
279
280 fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
281 let client = lhs.client.clone();
282 let desc = MatmulOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
283 client.create_empty_handle()
284 });
285
286 client
287 .register(OperationIr::Float(
288 desc.out.dtype,
289 FloatOperationIr::Matmul(desc),
290 ))
291 .output()
292 }
293
294 fn float_cross(
295 lhs: FloatTensor<Self>,
296 rhs: FloatTensor<Self>,
297 dim: usize,
298 ) -> FloatTensor<Self> {
299 let client = lhs.client.clone();
300 let desc = CrossOpIr::create(lhs.into_ir(), rhs.into_ir(), dim, || {
301 client.create_empty_handle()
302 });
303
304 client
305 .register(OperationIr::Float(
306 desc.out.dtype,
307 FloatOperationIr::Cross(desc),
308 ))
309 .output()
310 }
311
312 fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
313 let client = tensor.client.clone();
314 let desc = SwapDimsOpIr::create(tensor.into_ir(), dim1, dim2, || {
315 client.create_empty_handle()
316 });
317
318 client
319 .register(OperationIr::BaseFloat(BaseOperationIr::SwapDims(desc)))
320 .output()
321 }
322
323 fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
324 let client = tensor.client.clone();
325 let desc = ShapeOpIr::reshape(tensor.into_ir(), shape, || client.create_empty_handle());
326
327 client
328 .register(OperationIr::BaseFloat(BaseOperationIr::Reshape(desc)))
329 .output()
330 }
331
332 fn float_gather(
333 dim: usize,
334 tensor: FloatTensor<Self>,
335 indices: IntTensor<Self>,
336 ) -> FloatTensor<Self> {
337 let client = tensor.client.clone();
338 let desc = GatherOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
339 client.create_empty_handle()
340 });
341
342 client
343 .register(OperationIr::BaseFloat(BaseOperationIr::Gather(desc)))
344 .output()
345 }
346
347 fn float_scatter_add(
348 dim: usize,
349 tensor: FloatTensor<Self>,
350 indices: IntTensor<Self>,
351 value: FloatTensor<Self>,
352 ) -> FloatTensor<Self> {
353 let client = tensor.client.clone();
354 let desc = ScatterOpIr::create(
355 tensor.into_ir(),
356 dim,
357 indices.into_ir(),
358 value.into_ir(),
359 IndexingUpdateOp::Add,
360 || client.create_empty_handle(),
361 );
362
363 client
364 .register(OperationIr::BaseFloat(BaseOperationIr::Scatter(desc)))
365 .output()
366 }
367
368 fn float_select(
369 tensor: FloatTensor<Self>,
370 dim: usize,
371 indices: IntTensor<Self>,
372 ) -> FloatTensor<Self> {
373 let client = tensor.client.clone();
374 let desc = SelectOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
375 client.create_empty_handle()
376 });
377
378 client
379 .register(OperationIr::BaseFloat(BaseOperationIr::Select(desc)))
380 .output()
381 }
382
383 fn float_select_add(
384 tensor: FloatTensor<Self>,
385 dim: usize,
386 indices: IntTensor<Self>,
387 value: FloatTensor<Self>,
388 ) -> FloatTensor<Self> {
389 let client = tensor.client.clone();
390 let desc = SelectAssignOpIr::create(
391 tensor.into_ir(),
392 dim,
393 indices.into_ir(),
394 value.into_ir(),
395 IndexingUpdateOp::Add,
396 || client.create_empty_handle(),
397 );
398
399 client
400 .register(OperationIr::BaseFloat(BaseOperationIr::SelectAssign(desc)))
401 .output()
402 }
403
404 fn float_slice(tensor: FloatTensor<Self>, slices: &[Slice]) -> FloatTensor<Self> {
405 let client = tensor.client.clone();
406 let desc = SliceOpIr::create(tensor.into_ir(), slices.into(), || {
407 client.create_empty_handle()
408 });
409
410 client
411 .register(OperationIr::BaseFloat(BaseOperationIr::Slice(desc)))
412 .output()
413 }
414
415 fn float_slice_assign(
416 tensor: FloatTensor<Self>,
417 slices: &[burn_backend::Slice],
418 value: FloatTensor<Self>,
419 ) -> FloatTensor<Self> {
420 let client = tensor.client.clone();
421 let desc =
422 SliceAssignOpIr::create(tensor.into_ir(), slices.into(), value.into_ir(), || {
423 client.create_empty_handle()
424 });
425
426 client
427 .register(OperationIr::BaseFloat(BaseOperationIr::SliceAssign(desc)))
428 .output()
429 }
430
431 fn float_mask_where(
432 tensor: FloatTensor<Self>,
433 mask: BoolTensor<Self>,
434 value: FloatTensor<Self>,
435 ) -> FloatTensor<Self> {
436 let client = tensor.client.clone();
437 let desc = MaskWhereOpIr::create(tensor.into_ir(), mask.into_ir(), value.into_ir(), || {
438 client.create_empty_handle()
439 });
440
441 client
442 .register(OperationIr::BaseFloat(BaseOperationIr::MaskWhere(desc)))
443 .output()
444 }
445
446 fn float_mask_fill(
447 tensor: FloatTensor<Self>,
448 mask: BoolTensor<Self>,
449 value: FloatElem<Self>,
450 ) -> FloatTensor<Self> {
451 let client = tensor.client.clone();
452 let value = ScalarIr::with_dtype(value, &tensor.dtype);
453 let desc = MaskFillOpIr::create(tensor.into_ir(), mask.into_ir(), value, || {
454 client.create_empty_handle()
455 });
456
457 client
458 .register(OperationIr::BaseFloat(BaseOperationIr::MaskFill(desc)))
459 .output()
460 }
461
462 fn float_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
463 let client = lhs.client.clone();
464 let desc = BinaryOpIr::create_comparison(
465 lhs.into_ir(),
466 rhs.into_ir(),
467 R::BoolElem::dtype(),
468 || client.create_empty_handle(),
469 );
470
471 client
472 .register(OperationIr::BaseFloat(BaseOperationIr::Equal(desc)))
473 .output()
474 }
475
476 fn float_equal_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
477 let client = lhs.client.clone();
478 let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
479 let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || {
480 client.create_empty_handle()
481 });
482
483 client
484 .register(OperationIr::BaseFloat(BaseOperationIr::EqualElem(desc)))
485 .output()
486 }
487
488 fn float_greater(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
489 let client = lhs.client.clone();
490 let desc = BinaryOpIr::create_comparison(
491 lhs.into_ir(),
492 rhs.into_ir(),
493 R::BoolElem::dtype(),
494 || client.create_empty_handle(),
495 );
496
497 client
498 .register(OperationIr::NumericFloat(
499 desc.lhs.dtype,
500 NumericOperationIr::Greater(desc),
501 ))
502 .output()
503 }
504
505 fn float_greater_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
506 let client = lhs.client.clone();
507 let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
508 let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || {
509 client.create_empty_handle()
510 });
511
512 client
513 .register(OperationIr::NumericFloat(
514 desc.lhs.dtype,
515 NumericOperationIr::GreaterElem(desc),
516 ))
517 .output()
518 }
519
520 fn float_greater_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
521 let client = lhs.client.clone();
522 let desc = BinaryOpIr::create_comparison(
523 lhs.into_ir(),
524 rhs.into_ir(),
525 R::BoolElem::dtype(),
526 || client.create_empty_handle(),
527 );
528
529 client
530 .register(OperationIr::NumericFloat(
531 desc.lhs.dtype,
532 NumericOperationIr::GreaterEqual(desc),
533 ))
534 .output()
535 }
536
537 fn float_greater_equal_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
538 let client = lhs.client.clone();
539 let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
540 let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || {
541 client.create_empty_handle()
542 });
543
544 client
545 .register(OperationIr::NumericFloat(
546 desc.lhs.dtype,
547 NumericOperationIr::GreaterEqualElem(desc),
548 ))
549 .output()
550 }
551
552 fn float_lower(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
553 let client = lhs.client.clone();
554 let desc = BinaryOpIr::create_comparison(
555 lhs.into_ir(),
556 rhs.into_ir(),
557 R::BoolElem::dtype(),
558 || client.create_empty_handle(),
559 );
560
561 client
562 .register(OperationIr::NumericFloat(
563 desc.lhs.dtype,
564 NumericOperationIr::Lower(desc),
565 ))
566 .output()
567 }
568
569 fn float_lower_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
570 let client = lhs.client.clone();
571 let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
572 let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || {
573 client.create_empty_handle()
574 });
575
576 client
577 .register(OperationIr::NumericFloat(
578 desc.lhs.dtype,
579 NumericOperationIr::LowerElem(desc),
580 ))
581 .output()
582 }
583
584 fn float_lower_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {
585 let client = lhs.client.clone();
586 let desc = BinaryOpIr::create_comparison(
587 lhs.into_ir(),
588 rhs.into_ir(),
589 R::BoolElem::dtype(),
590 || client.create_empty_handle(),
591 );
592
593 client
594 .register(OperationIr::NumericFloat(
595 desc.lhs.dtype,
596 NumericOperationIr::LowerEqual(desc),
597 ))
598 .output()
599 }
600
601 fn float_lower_equal_elem(lhs: FloatTensor<Self>, rhs: FloatElem<Self>) -> BoolTensor<Self> {
602 let client = lhs.client.clone();
603 let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
604 let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || {
605 client.create_empty_handle()
606 });
607
608 client
609 .register(OperationIr::NumericFloat(
610 desc.lhs.dtype,
611 NumericOperationIr::LowerEqualElem(desc),
612 ))
613 .output()
614 }
615
616 fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
617 let client = tensor.client.clone();
618 let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
619
620 client
621 .register(OperationIr::NumericFloat(
622 desc.out.dtype,
623 NumericOperationIr::Sum(desc),
624 ))
625 .output()
626 }
627
628 fn float_sum_dim(tensor: FloatTensor<Self>, axis: usize) -> FloatTensor<Self> {
629 let client = tensor.client.clone();
630 let desc = ReduceDimOpIr::create(tensor.into_ir(), axis, || client.create_empty_handle());
631
632 client
633 .register(OperationIr::NumericFloat(
634 desc.out.dtype,
635 NumericOperationIr::SumDim(desc),
636 ))
637 .output()
638 }
639
640 fn float_prod(tensor: IntTensor<Self>) -> IntTensor<Self> {
641 let client = tensor.client.clone();
642 let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
643
644 client
645 .register(OperationIr::NumericFloat(
646 desc.out.dtype,
647 NumericOperationIr::Prod(desc),
648 ))
649 .output()
650 }
651
652 fn float_prod_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
653 let client = tensor.client.clone();
654 let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
655
656 client
657 .register(OperationIr::NumericFloat(
658 desc.out.dtype,
659 NumericOperationIr::ProdDim(desc),
660 ))
661 .output()
662 }
663
664 fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
665 let client = tensor.client.clone();
666 let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
667
668 client
669 .register(OperationIr::NumericFloat(
670 desc.out.dtype,
671 NumericOperationIr::Mean(desc),
672 ))
673 .output()
674 }
675
676 fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
677 let client = tensor.client.clone();
678 let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
679
680 client
681 .register(OperationIr::NumericFloat(
682 desc.out.dtype,
683 NumericOperationIr::MeanDim(desc),
684 ))
685 .output()
686 }
687
688 fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
689 let client = tensor.client.clone();
690 let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
691
692 client
693 .register(OperationIr::NumericFloat(
694 desc.out.dtype,
695 NumericOperationIr::CumSum(desc),
696 ))
697 .output()
698 }
699
700 fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
701 let client = tensor.client.clone();
702 let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
703
704 client
705 .register(OperationIr::NumericFloat(
706 desc.out.dtype,
707 NumericOperationIr::CumProd(desc),
708 ))
709 .output()
710 }
711
712 fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
713 let client = tensor.client.clone();
714 let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
715
716 client
717 .register(OperationIr::NumericFloat(
718 desc.out.dtype,
719 NumericOperationIr::CumMin(desc),
720 ))
721 .output()
722 }
723
724 fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
725 let client = tensor.client.clone();
726 let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
727
728 client
729 .register(OperationIr::NumericFloat(
730 desc.out.dtype,
731 NumericOperationIr::CumMax(desc),
732 ))
733 .output()
734 }
735
736 fn float_exp(lhs: FloatTensor<Self>) -> FloatTensor<Self> {
737 let client = lhs.client.clone();
738 let desc = UnaryOpIr::create(lhs.into_ir(), || client.create_empty_handle());
739
740 client
741 .register(OperationIr::Float(
742 desc.out.dtype,
743 FloatOperationIr::Exp(desc),
744 ))
745 .output()
746 }
747
748 fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
749 let client = tensor.client.clone();
750 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
751
752 client
753 .register(OperationIr::Float(
754 desc.out.dtype,
755 FloatOperationIr::Log(desc),
756 ))
757 .output()
758 }
759
760 fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
761 let client = tensor.client.clone();
762 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
763
764 client
765 .register(OperationIr::Float(
766 desc.out.dtype,
767 FloatOperationIr::Log1p(desc),
768 ))
769 .output()
770 }
771
772 fn float_powf_scalar_impl(lhs: FloatTensor<Self>, rhs: f32) -> FloatTensor<Self> {
773 let client = lhs.client.clone();
774 let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
775 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
776
777 client
778 .register(OperationIr::Float(
779 desc.out.dtype,
780 FloatOperationIr::PowfScalar(desc),
781 ))
782 .output()
783 }
784
785 fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
786 let client = tensor.client.clone();
787 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
788
789 client
790 .register(OperationIr::Float(
791 desc.out.dtype,
792 FloatOperationIr::Sqrt(desc),
793 ))
794 .output()
795 }
796
797 fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
798 let client = tensor.client.clone();
799 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
800
801 client
802 .register(OperationIr::NumericFloat(
803 desc.out.dtype,
804 NumericOperationIr::Abs(desc),
805 ))
806 .output()
807 }
808
809 fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
810 let client = tensor.client.clone();
811 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
812
813 client
814 .register(OperationIr::Float(
815 desc.out.dtype,
816 FloatOperationIr::Cos(desc),
817 ))
818 .output()
819 }
820
821 fn float_cosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
822 let client = tensor.client.clone();
823 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
824
825 client
826 .register(OperationIr::Float(
827 desc.out.dtype,
828 FloatOperationIr::Cosh(desc),
829 ))
830 .output()
831 }
832
833 fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
834 let client = tensor.client.clone();
835 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
836
837 client
838 .register(OperationIr::Float(
839 desc.out.dtype,
840 FloatOperationIr::Sin(desc),
841 ))
842 .output()
843 }
844
845 fn float_sinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
846 let client = tensor.client.clone();
847 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
848
849 client
850 .register(OperationIr::Float(
851 desc.out.dtype,
852 FloatOperationIr::Sinh(desc),
853 ))
854 .output()
855 }
856
857 fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
858 let client = tensor.client.clone();
859 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
860
861 client
862 .register(OperationIr::Float(
863 desc.out.dtype,
864 FloatOperationIr::Tan(desc),
865 ))
866 .output()
867 }
868
869 fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
870 let client = tensor.client.clone();
871 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
872
873 client
874 .register(OperationIr::Float(
875 desc.out.dtype,
876 FloatOperationIr::Tanh(desc),
877 ))
878 .output()
879 }
880
881 fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
882 let client = tensor.client.clone();
883 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
884
885 client
886 .register(OperationIr::Float(
887 desc.out.dtype,
888 FloatOperationIr::ArcCos(desc),
889 ))
890 .output()
891 }
892
893 fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
894 let client = tensor.client.clone();
895 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
896
897 client
898 .register(OperationIr::Float(
899 desc.out.dtype,
900 FloatOperationIr::ArcCosh(desc),
901 ))
902 .output()
903 }
904
905 fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
906 let client = tensor.client.clone();
907 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
908
909 client
910 .register(OperationIr::Float(
911 desc.out.dtype,
912 FloatOperationIr::ArcSin(desc),
913 ))
914 .output()
915 }
916
917 fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
918 let client = tensor.client.clone();
919 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
920
921 client
922 .register(OperationIr::Float(
923 desc.out.dtype,
924 FloatOperationIr::ArcSinh(desc),
925 ))
926 .output()
927 }
928
929 fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
930 let client = tensor.client.clone();
931 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
932
933 client
934 .register(OperationIr::Float(
935 desc.out.dtype,
936 FloatOperationIr::ArcTan(desc),
937 ))
938 .output()
939 }
940
941 fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
942 let client = tensor.client.clone();
943 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
944
945 client
946 .register(OperationIr::Float(
947 desc.out.dtype,
948 FloatOperationIr::ArcTanh(desc),
949 ))
950 .output()
951 }
952
953 fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
954 let client = lhs.client.clone();
955 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
956 client.create_empty_handle()
957 });
958
959 client
960 .register(OperationIr::Float(
961 desc.out.dtype,
962 FloatOperationIr::ArcTan2(desc),
963 ))
964 .output()
965 }
966
967 fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
968 let client = tensor.client.clone();
969 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
970
971 client
972 .register(OperationIr::Float(
973 desc.out.dtype,
974 FloatOperationIr::Round(desc),
975 ))
976 .output()
977 }
978
979 fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
980 let client = tensor.client.clone();
981 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
982
983 client
984 .register(OperationIr::Float(
985 desc.out.dtype,
986 FloatOperationIr::Floor(desc),
987 ))
988 .output()
989 }
990
991 fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
992 let client = tensor.client.clone();
993 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
994
995 client
996 .register(OperationIr::Float(
997 desc.out.dtype,
998 FloatOperationIr::Ceil(desc),
999 ))
1000 .output()
1001 }
1002
1003 fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1004 let client = tensor.client.clone();
1005 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1006
1007 client
1008 .register(OperationIr::Float(
1009 desc.out.dtype,
1010 FloatOperationIr::Trunc(desc),
1011 ))
1012 .output()
1013 }
1014
1015 fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1016 let client = tensor.client.clone();
1017 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1018
1019 client
1020 .register(OperationIr::Float(
1021 desc.out.dtype,
1022 FloatOperationIr::Recip(desc),
1023 ))
1024 .output()
1025 }
1026
1027 fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1028 let client = tensor.client.clone();
1029 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1030
1031 client
1032 .register(OperationIr::Float(
1033 desc.out.dtype,
1034 FloatOperationIr::Erf(desc),
1035 ))
1036 .output()
1037 }
1038
1039 fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {
1040 let client = tensors.first().unwrap().client.clone();
1041 let tensors = tensors.into_iter().map(|t| t.into_ir()).collect();
1042 let desc = CatOpIr::create(tensors, dim, || client.create_empty_handle());
1043
1044 client
1045 .register(OperationIr::BaseFloat(BaseOperationIr::Cat(desc)))
1046 .output()
1047 }
1048
1049 fn float_argmax(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {
1050 let client = tensor.client.clone();
1051 let desc =
1052 ReduceDimOpIr::create_arg(tensor.into_ir(), dim, IntElem::<Self>::dtype(), || {
1053 client.create_empty_handle()
1054 });
1055
1056 client
1057 .register(OperationIr::NumericFloat(
1058 desc.input.dtype,
1059 NumericOperationIr::ArgMax(desc),
1060 ))
1061 .output()
1062 }
1063
1064 fn float_repeat_dim(tensor: FloatTensor<Self>, dim: usize, times: usize) -> FloatTensor<Self> {
1065 let client = tensor.client.clone();
1066 let desc = RepeatDimOpIr::create(tensor.into_ir(), dim, times, || {
1067 client.create_empty_handle()
1068 });
1069
1070 client
1071 .register(OperationIr::BaseFloat(BaseOperationIr::RepeatDim(desc)))
1072 .output()
1073 }
1074
1075 fn float_argmin(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {
1076 let client = tensor.client.clone();
1077 let desc =
1078 ReduceDimOpIr::create_arg(tensor.into_ir(), dim, IntElem::<Self>::dtype(), || {
1079 client.create_empty_handle()
1080 });
1081
1082 client
1083 .register(OperationIr::NumericFloat(
1084 desc.input.dtype,
1085 NumericOperationIr::ArgMin(desc),
1086 ))
1087 .output()
1088 }
1089
1090 fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1091 let client = tensor.client.clone();
1092 let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1093
1094 client
1095 .register(OperationIr::NumericFloat(
1096 desc.out.dtype,
1097 NumericOperationIr::Max(desc),
1098 ))
1099 .output()
1100 }
1101
1102 fn float_max_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
1103 let client = tensor.client.clone();
1104 let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
1105
1106 client
1107 .register(OperationIr::NumericFloat(
1108 desc.out.dtype,
1109 NumericOperationIr::MaxDim(desc),
1110 ))
1111 .output()
1112 }
1113
1114 fn float_max_dim_with_indices(
1115 tensor: FloatTensor<Self>,
1116 dim: usize,
1117 ) -> (FloatTensor<Self>, IntTensor<Self>) {
1118 let client = tensor.client.clone();
1119 let desc = ReduceDimWithIndicesOpIr::create(
1120 tensor.into_ir(),
1121 dim,
1122 IntElem::<Self>::dtype(),
1123 || client.create_empty_handle(),
1124 );
1125
1126 client
1127 .register(OperationIr::NumericFloat(
1128 desc.tensor.dtype,
1129 NumericOperationIr::MaxDimWithIndices(desc),
1130 ))
1131 .outputs()
1132 .into()
1133 }
1134
1135 fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
1136 let client = tensor.client.clone();
1137 let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1138
1139 client
1140 .register(OperationIr::NumericFloat(
1141 desc.out.dtype,
1142 NumericOperationIr::Min(desc),
1143 ))
1144 .output()
1145 }
1146
1147 fn float_min_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
1148 let client = tensor.client.clone();
1149 let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
1150
1151 client
1152 .register(OperationIr::NumericFloat(
1153 desc.out.dtype,
1154 NumericOperationIr::MinDim(desc),
1155 ))
1156 .output()
1157 }
1158
1159 fn float_min_dim_with_indices(
1160 tensor: FloatTensor<Self>,
1161 dim: usize,
1162 ) -> (FloatTensor<Self>, IntTensor<Self>) {
1163 let client = tensor.client.clone();
1164 let desc = ReduceDimWithIndicesOpIr::create(
1165 tensor.into_ir(),
1166 dim,
1167 IntElem::<Self>::dtype(),
1168 || client.create_empty_handle(),
1169 );
1170
1171 client
1172 .register(OperationIr::NumericFloat(
1173 desc.tensor.dtype,
1174 NumericOperationIr::MinDimWithIndices(desc),
1175 ))
1176 .outputs()
1177 .into()
1178 }
1179
1180 fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
1181 let client = lhs.client.clone();
1182 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
1183 client.create_empty_handle()
1184 });
1185
1186 client
1187 .register(OperationIr::NumericFloat(
1188 desc.out.dtype,
1189 NumericOperationIr::Powf(desc),
1190 ))
1191 .output()
1192 }
1193
1194 fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
1195 let client = tensor.client.clone();
1196 let desc = PermuteOpIr::create(tensor.into_ir(), axes.into(), || {
1197 client.create_empty_handle()
1198 });
1199
1200 client
1201 .register(OperationIr::BaseFloat(BaseOperationIr::Permute(desc)))
1202 .output()
1203 }
1204
1205 fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
1206 let client = tensor.client.clone();
1207 let desc = ShapeOpIr::expand(tensor.into_ir(), shape, || client.create_empty_handle());
1208
1209 client
1210 .register(OperationIr::BaseFloat(BaseOperationIr::Expand(desc)))
1211 .output()
1212 }
1213
1214 fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
1215 let client = tensor.client.clone();
1216 let desc = FlipOpIr::create(tensor.into_ir(), axes.into(), || {
1217 client.create_empty_handle()
1218 });
1219
1220 client
1221 .register(OperationIr::BaseFloat(BaseOperationIr::Flip(desc)))
1222 .output()
1223 }
1224
1225 fn float_cast(tensor: FloatTensor<Self>, dtype: burn_backend::FloatDType) -> FloatTensor<Self> {
1226 let client = tensor.client.clone();
1227 let desc = CastOpIr::create(tensor.into_ir(), dtype.into(), || {
1228 client.create_empty_handle()
1229 });
1230
1231 client
1232 .register(OperationIr::BaseFloat(BaseOperationIr::Cast(desc)))
1233 .output()
1234 }
1235
1236 fn float_unfold(
1237 tensor: FloatTensor<Self>,
1238 dim: usize,
1239 size: usize,
1240 step: usize,
1241 ) -> FloatTensor<Self> {
1242 let client = tensor.client.clone();
1243 let desc = UnfoldOpIr::create(tensor.into_ir(), dim, size, step, || {
1244 client.create_empty_handle()
1245 });
1246
1247 client
1248 .register(OperationIr::BaseFloat(BaseOperationIr::Unfold(desc)))
1249 .output()
1250 }
1251}