1use super::NoOp;
2use crate::{
3 Fusion, FusionBackend, binary_int_cmp_ops, binary_int_ops, get_client, reduce_int_ops,
4 scalar_int_cmp_ops, scalar_int_ops,
5 stream::{OperationStreams, StreamId, execution::Operation},
6 unary_int_ops,
7};
8use burn_ir::*;
9use burn_tensor::ops::unfold::calculate_unfold_shape;
10use burn_tensor::{
11 Device, Distribution, Element, IntDType, Shape, Slice, TensorData, TensorMetadata,
12 ops::{BoolTensor, FloatTensor, IntElem, IntTensor, IntTensorOps},
13};
14use std::marker::PhantomData;
15
16impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
17 fn int_empty(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
18 #[derive(new, Debug)]
19 struct EmptyOps<B: FusionBackend> {
20 desc: TensorIr,
21 device: Device<B>,
22 }
23
24 impl<B: FusionBackend> Operation<B::FusionRuntime> for EmptyOps<B> {
25 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
26 let output = B::int_empty(
27 self.desc.shape.clone(),
28 &self.device,
29 self.desc.dtype.into(),
30 );
31 handles.register_int_tensor::<B>(&self.desc.id, output);
32 }
33 }
34
35 let client = get_client::<B>(&device.clone());
36 let out = client.tensor_uninitialized(shape, dtype.into());
37
38 let desc = out.to_ir_out();
39
40 client.register(
41 OperationStreams::default(),
42 OperationIr::BaseInt(BaseOperationIr::Empty(desc.clone())),
43 EmptyOps::<B>::new(desc, device.clone()),
44 );
45
46 out
47 }
48
    // Reads the tensor back to host memory; awaiting forces any pending fused
    // operations on the tensor's stream to complete first.
    async fn int_into_data(tensor: IntTensor<Self>) -> TensorData {
        tensor.int_into_data::<B>().await
    }
52
53 fn int_from_data(data: TensorData, device: &Device<Self>) -> IntTensor<Self> {
54 let stream = StreamId::current();
55 let client = get_client::<B>(&device.clone());
56 let dtype = data.dtype;
57 let tensor = B::int_from_data(data, device);
58 let shape = tensor.shape();
59
60 let handle = B::int_tensor_handle(tensor);
61 let out = client.register_tensor(handle, shape, stream, dtype);
62 let desc = out.to_ir_out();
63
64 client.register(
65 OperationStreams::default(),
66 OperationIr::Init(InitOperationIr { out: desc }),
67 NoOp::<B>::new(),
68 );
69
70 out
71 }
72
    // The device is owned by the fusion client the tensor is registered with.
    fn int_device(tensor: &IntTensor<Self>) -> Device<Self> {
        tensor.client.device().clone()
    }
76
77 fn int_to_device(tensor: IntTensor<Self>, device: &Device<Self>) -> IntTensor<Self> {
78 let device_original: &B::Device = tensor.client.device();
79 let device_target: B::Device = device.clone();
80
81 if device_original == &device_target {
82 return tensor;
83 }
84
85 let id = tensor.stream;
86 let client_target = get_client::<B>(&device_target);
87 let client_original = tensor.client.clone();
88
89 client_original
90 .clone()
91 .change_client_int::<B>(tensor.into_ir(), client_target, id)
92 }
93
94 fn int_reshape(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
95 if tensor.shape == shape {
96 return tensor;
97 }
98
99 #[derive(new, Debug)]
100 struct ReshapeDimsOps<B: FusionBackend> {
101 desc: UnaryOpIr,
102 _b: PhantomData<B>,
103 }
104
105 impl<B: FusionBackend> Operation<B::FusionRuntime> for ReshapeDimsOps<B> {
106 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
107 let input = handles.get_int_tensor::<B>(&self.desc.input);
108 let output = B::int_reshape(input, self.desc.out.shape.clone());
109 handles.register_int_tensor::<B>(&self.desc.out.id, output);
110 }
111 }
112
113 let dtype = tensor.dtype;
114 let mut streams = OperationStreams::default();
115 streams.tensor(&tensor);
116 let out = tensor.client.tensor_uninitialized(shape, dtype);
117
118 let desc = UnaryOpIr {
119 input: tensor.into_ir(),
120 out: out.to_ir_out(),
121 };
122 out.client.register(
123 streams,
124 OperationIr::BaseInt(BaseOperationIr::Reshape(desc.clone())),
125 ReshapeDimsOps::<B>::new(desc),
126 );
127
128 out
129 }
130
131 fn int_slice(tensor: IntTensor<Self>, slices: &[Slice]) -> IntTensor<Self> {
132 #[derive(new, Debug)]
133 struct SliceOps<B: FusionBackend> {
134 desc: SliceOpIr,
135 _b: PhantomData<B>,
136 }
137
138 impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceOps<B> {
139 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
140 let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);
141
142 let output = B::int_slice(tensor, self.desc.ranges.as_slice());
143
144 handles.register_int_tensor::<B>(&self.desc.out.id, output);
145 }
146 }
147
148 let mut streams = OperationStreams::default();
149 streams.tensor(&tensor);
150
151 let dtype = tensor.dtype;
152 let shape = tensor.shape.clone().slice(slices).unwrap();
153 let out = tensor.client.tensor_uninitialized(shape, dtype);
154
155 let desc = SliceOpIr {
156 tensor: tensor.into_ir(),
157 ranges: slices.to_vec(),
158 out: out.to_ir_out(),
159 };
160 out.client.register(
161 streams,
162 OperationIr::BaseInt(BaseOperationIr::Slice(desc.clone())),
163 SliceOps::<B>::new(desc),
164 );
165
166 out
167 }
168
169 fn int_slice_assign(
170 tensor: IntTensor<Self>,
171 ranges: &[burn_tensor::Slice],
172 value: IntTensor<Self>,
173 ) -> IntTensor<Self> {
174 #[derive(new, Debug)]
175 struct SliceAssignOps<B: FusionBackend> {
176 desc: SliceAssignOpIr,
177 _b: PhantomData<B>,
178 }
179
180 impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceAssignOps<B> {
181 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
182 let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);
183 let value = handles.get_int_tensor::<B>(&self.desc.value);
184
185 let output = B::int_slice_assign(tensor, self.desc.ranges.as_slice(), value);
186
187 handles.register_int_tensor::<B>(&self.desc.out.id, output);
188 }
189 }
190
191 let mut streams = OperationStreams::default();
192 streams.tensor(&tensor);
193 streams.tensor(&value);
194
195 let dtype = tensor.dtype;
196 let shape = tensor.shape.clone();
197 let out = tensor.client.tensor_uninitialized(shape, dtype);
198 let desc = SliceAssignOpIr {
199 tensor: tensor.into_ir(),
200 ranges: ranges.to_vec(),
201 value: value.into_ir(),
202 out: out.to_ir_out(),
203 };
204 out.client.register(
205 streams,
206 OperationIr::BaseInt(BaseOperationIr::SliceAssign(desc.clone())),
207 SliceAssignOps::<B>::new(desc),
208 );
209
210 out
211 }
212
213 fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
214 binary_int_ops!(MatmulOps, B::int_matmul);
215
216 let mut streams = OperationStreams::default();
217 streams.tensor(&lhs);
218 streams.tensor(&rhs);
219 let dtype = lhs.dtype;
220 let shape = Shape::matmul(&lhs.shape, &rhs.shape).unwrap();
221
222 let out = lhs.client.tensor_uninitialized(shape, dtype);
223 let desc = BinaryOpIr {
224 lhs: lhs.into_ir(),
225 rhs: rhs.into_ir(),
226 out: out.to_ir_out(),
227 };
228
229 out.client.register(
230 streams,
231 OperationIr::Float(dtype, FloatOperationIr::Matmul(desc.clone())),
232 MatmulOps::<B>::new(desc),
233 );
234
235 out
236 }
237
238 fn int_mask_where(
239 tensor: IntTensor<Self>,
240 mask: BoolTensor<Self>,
241 value: IntTensor<Self>,
242 ) -> IntTensor<Self> {
243 #[derive(new, Debug)]
244 struct MaskWhereOps<B: FusionBackend> {
245 desc: MaskWhereOpIr,
246 _b: PhantomData<B>,
247 }
248
249 impl<B: FusionBackend> Operation<B::FusionRuntime> for MaskWhereOps<B> {
250 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
251 let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);
252 let value = handles.get_int_tensor::<B>(&self.desc.value);
253 let mask = handles.get_bool_tensor::<B>(&self.desc.mask);
254
255 let output = B::int_mask_where(tensor, mask, value);
256
257 handles.register_int_tensor::<B>(&self.desc.out.id, output);
258 }
259 }
260
261 let mut streams = OperationStreams::default();
262 streams.tensor(&tensor);
263 streams.tensor(&value);
264 streams.tensor(&mask);
265
266 let dtype = tensor.dtype;
267 let out = tensor
268 .client
269 .tensor_uninitialized(tensor.shape.broadcast(&mask.shape).unwrap(), dtype);
270
271 let desc = MaskWhereOpIr {
272 tensor: tensor.into_ir(),
273 value: value.into_ir(),
274 mask: mask.into_ir(),
275 out: out.to_ir_out(),
276 };
277 out.client.register(
278 streams,
279 OperationIr::NumericInt(dtype, NumericOperationIr::MaskWhere(desc.clone())),
280 MaskWhereOps::<B>::new(desc),
281 );
282
283 out
284 }
285
286 fn int_mask_fill(
287 tensor: IntTensor<Self>,
288 mask: BoolTensor<Self>,
289 value: IntElem<Self>,
290 ) -> IntTensor<Self> {
291 #[derive(new, Debug)]
292 struct MaskFillOps<B: FusionBackend> {
293 desc: MaskFillOpIr,
294 _b: PhantomData<B>,
295 }
296
297 impl<B: FusionBackend> Operation<B::FusionRuntime> for MaskFillOps<B> {
298 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
299 let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);
300 let mask = handles.get_bool_tensor::<B>(&self.desc.mask);
301
302 let output = B::int_mask_fill(tensor, mask, self.desc.value.elem());
303
304 handles.register_int_tensor::<B>(&self.desc.out.id, output);
305 }
306 }
307
308 let mut streams = OperationStreams::default();
309 streams.tensor(&tensor);
310 streams.tensor(&mask);
311
312 let dtype = tensor.dtype;
313 let shape = tensor.shape.clone();
314 let out = tensor.client.tensor_uninitialized(shape, dtype);
315 let desc = MaskFillOpIr {
316 tensor: tensor.into_ir(),
317 value: ScalarIr::with_dtype(value, &dtype),
318 mask: mask.into_ir(),
319 out: out.to_ir_out(),
320 };
321 out.client.register(
322 streams,
323 OperationIr::NumericInt(dtype, NumericOperationIr::MaskFill(desc.clone())),
324 MaskFillOps::<B>::new(desc),
325 );
326
327 out
328 }
329
330 fn int_gather(
331 dim: usize,
332 tensor: IntTensor<Self>,
333 indices: IntTensor<Self>,
334 ) -> IntTensor<Self> {
335 #[derive(new, Debug)]
336 struct GatherOps<B: FusionBackend> {
337 desc: GatherOpIr,
338 _b: PhantomData<B>,
339 }
340
341 impl<B: FusionBackend> Operation<B::FusionRuntime> for GatherOps<B> {
342 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
343 let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);
344 let indices = handles.get_int_tensor::<B>(&self.desc.indices);
345
346 let output = B::int_gather(self.desc.dim, tensor, indices);
347 handles.register_int_tensor::<B>(&self.desc.out.id, output);
348 }
349 }
350
351 let mut streams = OperationStreams::default();
352 streams.tensor(&tensor);
353 streams.tensor(&indices);
354
355 let dtype = tensor.dtype;
356 let shape = indices.shape.clone();
357 let out = tensor.client.tensor_uninitialized(shape, dtype);
358 let desc = GatherOpIr {
359 tensor: tensor.into_ir(),
360 dim,
361 indices: indices.into_ir(),
362 out: out.to_ir_out(),
363 };
364 out.client.register(
365 streams,
366 OperationIr::NumericInt(dtype, NumericOperationIr::Gather(desc.clone())),
367 GatherOps::<B>::new(desc),
368 );
369
370 out
371 }
372
373 fn int_scatter(
374 dim: usize,
375 tensor: IntTensor<Self>,
376 indices: IntTensor<Self>,
377 value: IntTensor<Self>,
378 ) -> IntTensor<Self> {
379 #[derive(new, Debug)]
380 struct ScatterOps<B: FusionBackend> {
381 desc: ScatterOpIr,
382 _b: PhantomData<B>,
383 }
384
385 impl<B: FusionBackend> Operation<B::FusionRuntime> for ScatterOps<B> {
386 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
387 let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);
388 let indices = handles.get_int_tensor::<B>(&self.desc.indices);
389 let value = handles.get_int_tensor::<B>(&self.desc.value);
390
391 let output = B::int_scatter(self.desc.dim, tensor, indices, value);
392
393 handles.register_int_tensor::<B>(&self.desc.out.id, output);
394 }
395 }
396
397 let mut streams = OperationStreams::default();
398 streams.tensor(&tensor);
399 streams.tensor(&indices);
400 streams.tensor(&value);
401
402 let dtype = tensor.dtype;
403 let shape = tensor.shape.clone();
404 let out = tensor.client.tensor_uninitialized(shape, dtype);
405 let desc = ScatterOpIr {
406 tensor: tensor.into_ir(),
407 dim,
408 indices: indices.into_ir(),
409 value: value.into_ir(),
410 out: out.to_ir_out(),
411 };
412 out.client.register(
413 streams,
414 OperationIr::NumericInt(dtype, NumericOperationIr::Scatter(desc.clone())),
415 ScatterOps::<B>::new(desc),
416 );
417
418 out
419 }
420
421 fn int_select(
422 tensor: IntTensor<Self>,
423 dim: usize,
424 indices: IntTensor<Self>,
425 ) -> IntTensor<Self> {
426 #[derive(new, Debug)]
427 struct SelectOps<B: FusionBackend> {
428 desc: SelectOpIr,
429 _b: PhantomData<B>,
430 }
431
432 impl<B: FusionBackend> Operation<B::FusionRuntime> for SelectOps<B> {
433 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
434 let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);
435 let indices = handles.get_int_tensor::<B>(&self.desc.indices);
436
437 let output = B::int_select(tensor, self.desc.dim, indices);
438
439 handles.register_int_tensor::<B>(&self.desc.out.id, output);
440 }
441 }
442
443 let mut streams = OperationStreams::default();
444 streams.tensor(&tensor);
445 streams.tensor(&indices);
446
447 let dtype = tensor.dtype;
448 let mut shape = tensor.shape.clone();
449 shape[dim] = indices.shape[0];
450 let out = tensor.client.tensor_uninitialized(shape, dtype);
451 let desc = SelectOpIr {
452 tensor: tensor.into_ir(),
453 dim,
454 indices: indices.into_ir(),
455 out: out.to_ir_out(),
456 };
457 out.client.register(
458 streams,
459 OperationIr::NumericInt(dtype, NumericOperationIr::Select(desc.clone())),
460 SelectOps::<B>::new(desc),
461 );
462
463 out
464 }
465
466 fn int_select_assign(
467 tensor: IntTensor<Self>,
468 dim: usize,
469 indices: IntTensor<Self>,
470 value: IntTensor<Self>,
471 ) -> IntTensor<Self> {
472 #[derive(new, Debug)]
473 struct SelectAssignOps<B: FusionBackend> {
474 desc: SelectAssignOpIr,
475 _b: PhantomData<B>,
476 }
477
478 impl<B: FusionBackend> Operation<B::FusionRuntime> for SelectAssignOps<B> {
479 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
480 let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);
481 let indices = handles.get_int_tensor::<B>(&self.desc.indices);
482 let value = handles.get_int_tensor::<B>(&self.desc.value);
483
484 let output = B::int_select_assign(tensor, self.desc.dim, indices, value);
485
486 handles.register_int_tensor::<B>(&self.desc.out.id, output);
487 }
488 }
489
490 let mut streams = OperationStreams::default();
491 streams.tensor(&tensor);
492 streams.tensor(&indices);
493 streams.tensor(&value);
494
495 let dtype = tensor.dtype;
496 let shape = tensor.shape.clone();
497 let out = tensor.client.tensor_uninitialized(shape, dtype);
498 let desc = SelectAssignOpIr {
499 tensor: tensor.into_ir(),
500 dim,
501 indices: indices.into_ir(),
502 value: value.into_ir(),
503 out: out.to_ir_out(),
504 };
505 out.client.register(
506 streams,
507 OperationIr::NumericInt(dtype, NumericOperationIr::SelectAssign(desc.clone())),
508 SelectAssignOps::<B>::new(desc),
509 );
510
511 out
512 }
513
514 fn int_cat(tensors: Vec<IntTensor<Self>>, dim: usize) -> IntTensor<Self> {
515 #[derive(new, Debug)]
516 struct CatOps<B: FusionBackend> {
517 desc: CatOpIr,
518 _b: PhantomData<B>,
519 }
520
521 impl<B: FusionBackend> Operation<B::FusionRuntime> for CatOps<B> {
522 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
523 let tensors = self
524 .desc
525 .tensors
526 .iter()
527 .map(|tensor| handles.get_int_tensor::<B>(tensor))
528 .collect();
529
530 let output = B::int_cat(tensors, self.desc.dim);
531
532 handles.register_int_tensor::<B>(&self.desc.out.id, output);
533 }
534 }
535
536 let tensor_first = tensors.first().unwrap();
537 let client = tensor_first.client.clone();
538
539 let mut streams = OperationStreams::default();
540 tensors.iter().for_each(|tensor| streams.tensor(tensor));
541
542 let shape = Shape::cat(tensors.iter().map(|t| &t.shape), dim).unwrap();
544
545 let dtype = tensor_first.dtype;
546 let out = client.tensor_uninitialized(shape, dtype);
547
548 let desc = CatOpIr {
549 tensors: tensors.into_iter().map(|t| t.into_ir()).collect(),
550 dim,
551 out: out.to_ir_out(),
552 };
553 client.register(
554 streams,
555 OperationIr::BaseInt(BaseOperationIr::Cat(desc.clone())),
556 CatOps::<B>::new(desc),
557 );
558
559 out
560 }
561
562 fn int_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
563 binary_int_cmp_ops!(EqualOps, B::int_equal);
564
565 let mut streams = OperationStreams::default();
566 streams.tensor(&lhs);
567 streams.tensor(&rhs);
568 let out = lhs.client.tensor_uninitialized(
569 lhs.shape.broadcast(&rhs.shape).unwrap(),
570 B::BoolElem::dtype(),
571 );
572
573 let desc = BinaryOpIr {
574 lhs: lhs.into_ir(),
575 rhs: rhs.into_ir(),
576 out: out.to_ir_out(),
577 };
578 out.client.register(
579 streams,
580 OperationIr::BaseInt(BaseOperationIr::Equal(desc.clone())),
581 EqualOps::<B>::new(desc),
582 );
583
584 out
585 }
586
587 fn int_equal_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
588 scalar_int_cmp_ops!(EqualElemOps, B::int_equal_elem);
589
590 let mut streams = OperationStreams::default();
591 streams.tensor(&lhs);
592 let out = lhs
593 .client
594 .tensor_uninitialized(lhs.shape.clone(), B::BoolElem::dtype());
595
596 let dtype = lhs.dtype;
597 let desc = ScalarOpIr {
598 lhs: lhs.into_ir(),
599 rhs: ScalarIr::with_dtype(rhs, &dtype),
600 out: out.to_ir_out(),
601 };
602 out.client.register(
603 streams,
604 OperationIr::NumericInt(dtype, NumericOperationIr::EqualElem(desc.clone())),
605 EqualElemOps::<B>::new(desc),
606 );
607
608 out
609 }
610
611 fn int_greater(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
612 binary_int_cmp_ops!(GreaterOps, B::int_greater);
613
614 let mut streams = OperationStreams::default();
615 streams.tensor(&lhs);
616 streams.tensor(&rhs);
617 let out = lhs.client.tensor_uninitialized(
618 lhs.shape.broadcast(&rhs.shape).unwrap(),
619 B::BoolElem::dtype(),
620 );
621
622 let dtype = lhs.dtype;
623 let desc = BinaryOpIr {
624 lhs: lhs.into_ir(),
625 rhs: rhs.into_ir(),
626 out: out.to_ir_out(),
627 };
628 out.client.register(
629 streams,
630 OperationIr::NumericInt(dtype, NumericOperationIr::Greater(desc.clone())),
631 GreaterOps::<B>::new(desc),
632 );
633
634 out
635 }
636
637 fn int_greater_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
638 scalar_int_cmp_ops!(GreaterElemOps, B::int_greater_elem);
639
640 let mut streams = OperationStreams::default();
641 streams.tensor(&lhs);
642 let out = lhs
643 .client
644 .tensor_uninitialized(lhs.shape.clone(), B::BoolElem::dtype());
645
646 let dtype = lhs.dtype;
647 let desc = ScalarOpIr {
648 lhs: lhs.into_ir(),
649 rhs: ScalarIr::with_dtype(rhs, &dtype),
650 out: out.to_ir_out(),
651 };
652 out.client.register(
653 streams,
654 OperationIr::NumericInt(dtype, NumericOperationIr::GreaterElem(desc.clone())),
655 GreaterElemOps::<B>::new(desc),
656 );
657
658 out
659 }
660
661 fn int_greater_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
662 binary_int_cmp_ops!(GreaterEqualOps, B::int_greater_equal);
663
664 let mut streams = OperationStreams::default();
665 streams.tensor(&lhs);
666 streams.tensor(&rhs);
667 let out = lhs.client.tensor_uninitialized(
668 lhs.shape.broadcast(&rhs.shape).unwrap(),
669 B::BoolElem::dtype(),
670 );
671
672 let dtype = lhs.dtype;
673 let desc = BinaryOpIr {
674 lhs: lhs.into_ir(),
675 rhs: rhs.into_ir(),
676 out: out.to_ir_out(),
677 };
678 out.client.register(
679 streams,
680 OperationIr::NumericInt(dtype, NumericOperationIr::GreaterEqual(desc.clone())),
681 GreaterEqualOps::<B>::new(desc),
682 );
683
684 out
685 }
686
687 fn int_greater_equal_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
688 scalar_int_cmp_ops!(GreaterEqualElemOps, B::int_greater_equal_elem);
689
690 let mut streams = OperationStreams::default();
691 streams.tensor(&lhs);
692 let out = lhs
693 .client
694 .tensor_uninitialized(lhs.shape.clone(), B::BoolElem::dtype());
695
696 let dtype = lhs.dtype;
697 let desc = ScalarOpIr {
698 lhs: lhs.into_ir(),
699 rhs: ScalarIr::with_dtype(rhs, &dtype),
700 out: out.to_ir_out(),
701 };
702 out.client.register(
703 streams,
704 OperationIr::NumericInt(dtype, NumericOperationIr::GreaterEqualElem(desc.clone())),
705 GreaterEqualElemOps::<B>::new(desc),
706 );
707
708 out
709 }
710
711 fn int_lower(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
712 binary_int_cmp_ops!(LowerOps, B::int_lower);
713
714 let mut streams = OperationStreams::default();
715 streams.tensor(&lhs);
716 streams.tensor(&rhs);
717 let out = lhs.client.tensor_uninitialized(
718 lhs.shape.broadcast(&rhs.shape).unwrap(),
719 B::BoolElem::dtype(),
720 );
721
722 let dtype = lhs.dtype;
723 let desc = BinaryOpIr {
724 lhs: lhs.into_ir(),
725 rhs: rhs.into_ir(),
726 out: out.to_ir_out(),
727 };
728 out.client.register(
729 streams,
730 OperationIr::NumericInt(dtype, NumericOperationIr::Lower(desc.clone())),
731 LowerOps::<B>::new(desc),
732 );
733
734 out
735 }
736
737 fn int_lower_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
738 scalar_int_cmp_ops!(LowerElemOps, B::int_lower_elem);
739
740 let mut streams = OperationStreams::default();
741 streams.tensor(&lhs);
742 let out = lhs
743 .client
744 .tensor_uninitialized(lhs.shape.clone(), B::BoolElem::dtype());
745
746 let dtype = lhs.dtype;
747 let desc = ScalarOpIr {
748 lhs: lhs.into_ir(),
749 rhs: ScalarIr::with_dtype(rhs, &dtype),
750 out: out.to_ir_out(),
751 };
752 out.client.register(
753 streams,
754 OperationIr::NumericInt(dtype, NumericOperationIr::LowerElem(desc.clone())),
755 LowerElemOps::<B>::new(desc),
756 );
757
758 out
759 }
760
761 fn int_lower_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
762 binary_int_cmp_ops!(LowerEqualOps, B::int_lower_equal);
763
764 let mut streams = OperationStreams::default();
765 streams.tensor(&lhs);
766 streams.tensor(&rhs);
767 let out = lhs.client.tensor_uninitialized(
768 lhs.shape.broadcast(&rhs.shape).unwrap(),
769 B::BoolElem::dtype(),
770 );
771
772 let dtype = lhs.dtype;
773 let desc = BinaryOpIr {
774 lhs: lhs.into_ir(),
775 rhs: rhs.into_ir(),
776 out: out.to_ir_out(),
777 };
778 out.client.register(
779 streams,
780 OperationIr::NumericInt(dtype, NumericOperationIr::LowerEqual(desc.clone())),
781 LowerEqualOps::<B>::new(desc),
782 );
783
784 out
785 }
786
787 fn int_lower_equal_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
788 scalar_int_cmp_ops!(LowerEqualElemOps, B::int_lower_equal_elem);
789
790 let mut streams = OperationStreams::default();
791 streams.tensor(&lhs);
792 let out = lhs
793 .client
794 .tensor_uninitialized(lhs.shape.clone(), B::BoolElem::dtype());
795
796 let dtype = lhs.dtype;
797 let desc = ScalarOpIr {
798 lhs: lhs.into_ir(),
799 rhs: ScalarIr::with_dtype(rhs, &dtype),
800 out: out.to_ir_out(),
801 };
802 out.client.register(
803 streams,
804 OperationIr::NumericInt(dtype, NumericOperationIr::LowerEqualElem(desc.clone())),
805 LowerEqualElemOps::<B>::new(desc),
806 );
807
808 out
809 }
810
811 fn int_add(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
812 binary_int_ops!(AddOps, B::int_add);
813
814 let dtype = lhs.dtype;
815 let mut streams = OperationStreams::default();
816 streams.tensor(&lhs);
817 streams.tensor(&rhs);
818 let out = lhs
819 .client
820 .tensor_uninitialized(lhs.shape.broadcast(&rhs.shape).unwrap(), dtype);
821
822 let desc = BinaryOpIr {
823 lhs: lhs.into_ir(),
824 rhs: rhs.into_ir(),
825 out: out.to_ir_out(),
826 };
827 out.client.register(
828 streams,
829 OperationIr::NumericInt(dtype, NumericOperationIr::Add(desc.clone())),
830 AddOps::<B>::new(desc),
831 );
832
833 out
834 }
835
836 fn int_add_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
837 scalar_int_ops!(AddOps, B::int_add_scalar);
838
839 let dtype = lhs.dtype;
840 let mut streams = OperationStreams::default();
841 streams.tensor(&lhs);
842 let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype);
843
844 let desc = ScalarOpIr {
845 lhs: lhs.into_ir(),
846 rhs: ScalarIr::with_dtype(rhs, &dtype),
847 out: out.to_ir_out(),
848 };
849 out.client.register(
850 streams,
851 OperationIr::NumericInt(dtype, NumericOperationIr::AddScalar(desc.clone())),
852 AddOps::<B>::new(desc),
853 );
854
855 out
856 }
857
858 fn int_sub(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
859 binary_int_ops!(SubOps, B::int_sub);
860
861 let dtype = lhs.dtype;
862 let mut streams = OperationStreams::default();
863 streams.tensor(&lhs);
864 streams.tensor(&rhs);
865 let out = lhs
866 .client
867 .tensor_uninitialized(lhs.shape.broadcast(&rhs.shape).unwrap(), dtype);
868
869 let desc = BinaryOpIr {
870 lhs: lhs.into_ir(),
871 rhs: rhs.into_ir(),
872 out: out.to_ir_out(),
873 };
874 out.client.register(
875 streams,
876 OperationIr::NumericInt(dtype, NumericOperationIr::Sub(desc.clone())),
877 SubOps::<B>::new(desc),
878 );
879
880 out
881 }
882
883 fn int_sub_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
884 scalar_int_ops!(SubOps, B::int_sub_scalar);
885
886 let dtype = lhs.dtype;
887 let mut streams = OperationStreams::default();
888 streams.tensor(&lhs);
889 let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype);
890
891 let desc = ScalarOpIr {
892 lhs: lhs.into_ir(),
893 rhs: ScalarIr::with_dtype(rhs, &dtype),
894 out: out.to_ir_out(),
895 };
896 out.client.register(
897 streams,
898 OperationIr::NumericInt(dtype, NumericOperationIr::SubScalar(desc.clone())),
899 SubOps::<B>::new(desc),
900 );
901
902 out
903 }
904
905 fn int_mul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
906 binary_int_ops!(MulOps, B::int_mul);
907
908 let dtype = lhs.dtype;
909 let mut streams = OperationStreams::default();
910 streams.tensor(&lhs);
911 streams.tensor(&rhs);
912 let out = lhs
913 .client
914 .tensor_uninitialized(lhs.shape.broadcast(&rhs.shape).unwrap(), dtype);
915
916 let desc = BinaryOpIr {
917 lhs: lhs.into_ir(),
918 rhs: rhs.into_ir(),
919 out: out.to_ir_out(),
920 };
921 out.client.register(
922 streams,
923 OperationIr::NumericInt(dtype, NumericOperationIr::Mul(desc.clone())),
924 MulOps::<B>::new(desc),
925 );
926
927 out
928 }
929
930 fn int_mul_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
931 scalar_int_ops!(MulOps, B::int_mul_scalar);
932
933 let dtype = lhs.dtype;
934 let mut streams = OperationStreams::default();
935 streams.tensor(&lhs);
936 let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype);
937
938 let desc = ScalarOpIr {
939 lhs: lhs.into_ir(),
940 rhs: ScalarIr::with_dtype(rhs, &dtype),
941 out: out.to_ir_out(),
942 };
943 out.client.register(
944 streams,
945 OperationIr::NumericInt(dtype, NumericOperationIr::MulScalar(desc.clone())),
946 MulOps::<B>::new(desc),
947 );
948
949 out
950 }
951
952 fn int_div(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
953 binary_int_ops!(DivOps, B::int_div);
954
955 let dtype = lhs.dtype;
956 let mut streams = OperationStreams::default();
957 streams.tensor(&lhs);
958 streams.tensor(&rhs);
959 let out = lhs
960 .client
961 .tensor_uninitialized(lhs.shape.broadcast(&rhs.shape).unwrap(), dtype);
962
963 let desc = BinaryOpIr {
964 lhs: lhs.into_ir(),
965 rhs: rhs.into_ir(),
966 out: out.to_ir_out(),
967 };
968 out.client.register(
969 streams,
970 OperationIr::NumericInt(dtype, NumericOperationIr::Div(desc.clone())),
971 DivOps::<B>::new(desc),
972 );
973
974 out
975 }
976
977 fn int_div_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
978 scalar_int_ops!(DivOps, B::int_div_scalar);
979
980 let dtype = lhs.dtype;
981 let mut streams = OperationStreams::default();
982 streams.tensor(&lhs);
983 let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype);
984
985 let desc = ScalarOpIr {
986 lhs: lhs.into_ir(),
987 rhs: ScalarIr::with_dtype(rhs, &dtype),
988 out: out.to_ir_out(),
989 };
990 out.client.register(
991 streams,
992 OperationIr::NumericInt(dtype, NumericOperationIr::DivScalar(desc.clone())),
993 DivOps::<B>::new(desc),
994 );
995
996 out
997 }
998
999 fn int_remainder(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
1000 binary_int_ops!(ModOps, B::int_remainder);
1001
1002 let dtype = lhs.dtype;
1003 let mut streams = OperationStreams::default();
1004 streams.tensor(&lhs);
1005 streams.tensor(&rhs);
1006 let out = lhs
1007 .client
1008 .tensor_uninitialized(lhs.shape.broadcast(&rhs.shape).unwrap(), dtype);
1009
1010 let desc = BinaryOpIr {
1011 lhs: lhs.into_ir(),
1012 rhs: rhs.into_ir(),
1013 out: out.to_ir_out(),
1014 };
1015 out.client.register(
1016 streams,
1017 OperationIr::NumericInt(dtype, NumericOperationIr::Rem(desc.clone())),
1018 ModOps::<B>::new(desc),
1019 );
1020
1021 out
1022 }
1023
1024 fn int_remainder_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
1025 scalar_int_ops!(ModOps, B::int_remainder_scalar);
1026
1027 let dtype = lhs.dtype;
1028 let mut streams = OperationStreams::default();
1029 streams.tensor(&lhs);
1030 let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype);
1031
1032 let desc = ScalarOpIr {
1033 lhs: lhs.into_ir(),
1034 rhs: ScalarIr::with_dtype(rhs, &dtype),
1035 out: out.to_ir_out(),
1036 };
1037 out.client.register(
1038 streams,
1039 OperationIr::NumericInt(dtype, NumericOperationIr::RemScalar(desc.clone())),
1040 ModOps::<B>::new(desc),
1041 );
1042
1043 out
1044 }
1045
1046 fn int_zeros(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
1047 #[derive(new, Debug)]
1048 struct ZerosOps<B: FusionBackend> {
1049 desc: TensorIr,
1050 device: Device<B>,
1051 }
1052
1053 impl<B: FusionBackend> Operation<B::FusionRuntime> for ZerosOps<B> {
1054 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1055 let shape = self.desc.shape.clone();
1056 let output = B::int_zeros(shape, &self.device, self.desc.dtype.into());
1057 handles.register_int_tensor::<B>(&self.desc.id, output);
1058 }
1059 }
1060
1061 let dtype = dtype.into();
1062 let client = get_client::<B>(&device.clone());
1063 let out = client.tensor_uninitialized(shape, dtype);
1064 let desc = out.to_ir_out();
1065 client.register(
1066 OperationStreams::default(),
1067 OperationIr::NumericInt(dtype, NumericOperationIr::Zeros(desc.clone())),
1068 ZerosOps::<B>::new(desc, device.clone()),
1069 );
1070
1071 out
1072 }
1073
1074 fn int_ones(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
1075 #[derive(new, Debug)]
1076 struct OnesOps<B: FusionBackend> {
1077 desc: TensorIr,
1078 device: Device<B>,
1079 }
1080
1081 impl<B: FusionBackend> Operation<B::FusionRuntime> for OnesOps<B> {
1082 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1083 let shape = self.desc.shape.clone();
1084 let output = B::int_ones(shape, &self.device, self.desc.dtype.into());
1085 handles.register_int_tensor::<B>(&self.desc.id, output);
1086 }
1087 }
1088
1089 let dtype = dtype.into();
1090 let client = get_client::<B>(&device.clone());
1091 let out = client.tensor_uninitialized(shape, dtype);
1092
1093 let desc = out.to_ir_out();
1094 client.register(
1095 OperationStreams::default(),
1096 OperationIr::NumericInt(dtype, NumericOperationIr::Ones(desc.clone())),
1097 OnesOps::<B>::new(desc, device.clone()),
1098 );
1099
1100 out
1101 }
1102
    fn int_full(
        shape: Shape,
        fill_value: IntElem<Self>,
        device: &Device<Self>,
        dtype: IntDType,
    ) -> IntTensor<Self> {
        // Deferred op: fills a new int tensor with `fill_value` at execution time.
        // The scalar is captured as a `ScalarIr` tagged with the output dtype.
        #[derive(new, Debug)]
        struct FullOps<B: FusionBackend> {
            out: TensorIr,
            elem: ScalarIr,
            device: Device<B>,
        }

        impl<B: FusionBackend> Operation<B::FusionRuntime> for FullOps<B> {
            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
                let shape = self.out.shape.clone();
                let output =
                    B::int_full(shape, self.elem.elem(), &self.device, self.out.dtype.into());
                handles.register_int_tensor::<B>(&self.out.id, output);
            }
        }

        let dtype = dtype.into();
        let client = get_client::<B>(&device.clone());
        let out = client.tensor_uninitialized(shape, dtype);

        // The IR variant carries both the output descriptor and the scalar value.
        let desc = (out.to_ir_out(), ScalarIr::with_dtype(fill_value, &dtype));
        client.register(
            // Creation op: no input tensors, hence no streams to synchronize with.
            OperationStreams::default(),
            OperationIr::NumericInt(dtype, NumericOperationIr::Full(desc.clone())),
            FullOps::<B>::new(desc.0, desc.1, device.clone()),
        );

        out
    }
1138
1139 fn int_sum(tensor: IntTensor<Self>) -> IntTensor<Self> {
1140 unary_int_ops!(SumOps, B::int_sum, reduce);
1141
1142 let dtype = tensor.dtype;
1143 let mut streams = OperationStreams::default();
1144 streams.tensor(&tensor);
1145 let out = tensor.client.tensor_uninitialized(Shape::new([1]), dtype);
1146
1147 let desc = UnaryOpIr {
1148 input: tensor.into_ir(),
1149 out: out.to_ir_out(),
1150 };
1151 out.client.register(
1152 streams,
1153 OperationIr::NumericInt(dtype, NumericOperationIr::Sum(desc.clone())),
1154 SumOps::<B>::new(desc),
1155 );
1156
1157 out
1158 }
1159
1160 fn int_sum_dim(tensor: IntTensor<Self>, axis: usize) -> IntTensor<Self> {
1161 reduce_int_ops!(SumDimOps, B::int_sum_dim);
1162
1163 let dtype = tensor.dtype;
1164 let mut streams = OperationStreams::default();
1165 streams.tensor(&tensor);
1166 let mut shape = tensor.shape.clone();
1167 shape[axis] = 1;
1168 let out = tensor.client.tensor_uninitialized(shape, dtype);
1169
1170 let desc = ReduceDimOpIr {
1171 out: out.to_ir_out(),
1172 input: tensor.into_ir(),
1173 axis,
1174 };
1175 out.client.register(
1176 streams,
1177 OperationIr::NumericInt(dtype, NumericOperationIr::SumDim(desc.clone())),
1178 SumDimOps::<B>::new(desc),
1179 );
1180
1181 out
1182 }
1183
1184 fn int_prod(tensor: IntTensor<Self>) -> IntTensor<Self> {
1185 unary_int_ops!(ProdOps, B::int_prod, reduce);
1186
1187 let dtype = tensor.dtype;
1188 let mut streams = OperationStreams::default();
1189 streams.tensor(&tensor);
1190 let out = tensor.client.tensor_uninitialized(Shape::new([1]), dtype);
1191
1192 let desc = UnaryOpIr {
1193 input: tensor.into_ir(),
1194 out: out.to_ir_out(),
1195 };
1196 out.client.register(
1197 streams,
1198 OperationIr::NumericInt(dtype, NumericOperationIr::Prod(desc.clone())),
1199 ProdOps::<B>::new(desc),
1200 );
1201
1202 out
1203 }
1204
1205 fn int_prod_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
1206 reduce_int_ops!(ProdDimOps, B::int_prod_dim);
1207
1208 let dtype = tensor.dtype;
1209 let mut streams = OperationStreams::default();
1210 streams.tensor(&tensor);
1211 let mut shape = tensor.shape.clone();
1212 shape[dim] = 1;
1213 let out = tensor.client.tensor_uninitialized(shape, dtype);
1214
1215 let desc = ReduceDimOpIr {
1216 input: tensor.into_ir(),
1217 axis: dim,
1218 out: out.to_ir_out(),
1219 };
1220 out.client.register(
1221 streams,
1222 OperationIr::NumericInt(dtype, NumericOperationIr::ProdDim(desc.clone())),
1223 ProdDimOps::<B>::new(desc),
1224 );
1225
1226 out
1227 }
1228
1229 fn int_mean(tensor: IntTensor<Self>) -> IntTensor<Self> {
1230 unary_int_ops!(MeanOps, B::int_mean, reduce);
1231
1232 let dtype = tensor.dtype;
1233 let mut streams = OperationStreams::default();
1234 streams.tensor(&tensor);
1235 let out = tensor.client.tensor_uninitialized(Shape::new([1]), dtype);
1236
1237 let desc = UnaryOpIr {
1238 input: tensor.into_ir(),
1239 out: out.to_ir_out(),
1240 };
1241 out.client.register(
1242 streams,
1243 OperationIr::NumericInt(dtype, NumericOperationIr::Mean(desc.clone())),
1244 MeanOps::<B>::new(desc),
1245 );
1246
1247 out
1248 }
1249
1250 fn int_mean_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
1251 reduce_int_ops!(MeanDimOps, B::int_mean_dim);
1252
1253 let dtype = tensor.dtype;
1254 let mut streams = OperationStreams::default();
1255 streams.tensor(&tensor);
1256 let mut shape = tensor.shape.clone();
1257 shape[dim] = 1;
1258 let out = tensor.client.tensor_uninitialized(shape, dtype);
1259
1260 let desc = ReduceDimOpIr {
1261 input: tensor.into_ir(),
1262 axis: dim,
1263 out: out.to_ir_out(),
1264 };
1265 out.client.register(
1266 streams,
1267 OperationIr::NumericInt(dtype, NumericOperationIr::MeanDim(desc.clone())),
1268 MeanDimOps::<B>::new(desc),
1269 );
1270
1271 out
1272 }
1273
    fn int_cumsum(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        // Deferred cumulative sum along `dim`, executed when the stream flushes.
        #[derive(new, Debug)]
        struct CumsumOps<B: FusionBackend> {
            desc: DimOpIr,
            _b: PhantomData<B>,
        }

        impl<B: FusionBackend> Operation<B::FusionRuntime> for CumsumOps<B> {
            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
                let input = handles.get_int_tensor::<B>(&self.desc.input);
                let output = B::int_cumsum(input, self.desc.axis);
                handles.register_int_tensor::<B>(&self.desc.out.id, output);
            }
        }

        let dtype = tensor.dtype;
        let mut streams = OperationStreams::default();
        streams.tensor(&tensor);
        // Cumulative ops preserve the input shape.
        let shape = tensor.shape.clone();
        let out = tensor.client.tensor_uninitialized(shape, dtype);

        let desc = DimOpIr {
            out: out.to_ir_out(),
            // `into_ir` consumes `tensor`, so all borrows of it happen above.
            input: tensor.into_ir(),
            axis: dim,
        };
        out.client.register(
            streams,
            OperationIr::BaseInt(BaseOperationIr::CumSum(desc.clone())),
            CumsumOps::<B>::new(desc),
        );

        out
    }
1308
    fn int_cumprod(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        // Deferred cumulative product along `dim`, executed when the stream flushes.
        #[derive(new, Debug)]
        struct CumprodOps<B: FusionBackend> {
            desc: DimOpIr,
            _b: PhantomData<B>,
        }

        impl<B: FusionBackend> Operation<B::FusionRuntime> for CumprodOps<B> {
            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
                let input = handles.get_int_tensor::<B>(&self.desc.input);
                let output = B::int_cumprod(input, self.desc.axis);
                handles.register_int_tensor::<B>(&self.desc.out.id, output);
            }
        }

        let dtype = tensor.dtype;
        let mut streams = OperationStreams::default();
        streams.tensor(&tensor);
        // Cumulative ops preserve the input shape.
        let shape = tensor.shape.clone();
        let out = tensor.client.tensor_uninitialized(shape, dtype);

        let desc = DimOpIr {
            out: out.to_ir_out(),
            // `into_ir` consumes `tensor`, so all borrows of it happen above.
            input: tensor.into_ir(),
            axis: dim,
        };
        out.client.register(
            streams,
            OperationIr::BaseInt(BaseOperationIr::CumProd(desc.clone())),
            CumprodOps::<B>::new(desc),
        );

        out
    }
1343
    fn int_cummin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        // Deferred running minimum along `dim`, executed when the stream flushes.
        #[derive(new, Debug)]
        struct CumminOps<B: FusionBackend> {
            desc: DimOpIr,
            _b: PhantomData<B>,
        }

        impl<B: FusionBackend> Operation<B::FusionRuntime> for CumminOps<B> {
            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
                let input = handles.get_int_tensor::<B>(&self.desc.input);
                let output = B::int_cummin(input, self.desc.axis);
                handles.register_int_tensor::<B>(&self.desc.out.id, output);
            }
        }

        let dtype = tensor.dtype;
        let mut streams = OperationStreams::default();
        streams.tensor(&tensor);
        // Cumulative ops preserve the input shape.
        let shape = tensor.shape.clone();
        let out = tensor.client.tensor_uninitialized(shape, dtype);

        let desc = DimOpIr {
            out: out.to_ir_out(),
            // `into_ir` consumes `tensor`, so all borrows of it happen above.
            input: tensor.into_ir(),
            axis: dim,
        };
        out.client.register(
            streams,
            OperationIr::BaseInt(BaseOperationIr::CumMin(desc.clone())),
            CumminOps::<B>::new(desc),
        );

        out
    }
1378
    fn int_cummax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        // Deferred running maximum along `dim`, executed when the stream flushes.
        #[derive(new, Debug)]
        struct CummaxOps<B: FusionBackend> {
            desc: DimOpIr,
            _b: PhantomData<B>,
        }

        impl<B: FusionBackend> Operation<B::FusionRuntime> for CummaxOps<B> {
            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
                let input = handles.get_int_tensor::<B>(&self.desc.input);
                let output = B::int_cummax(input, self.desc.axis);
                handles.register_int_tensor::<B>(&self.desc.out.id, output);
            }
        }

        let dtype = tensor.dtype;
        let mut streams = OperationStreams::default();
        streams.tensor(&tensor);
        // Cumulative ops preserve the input shape.
        let shape = tensor.shape.clone();
        let out = tensor.client.tensor_uninitialized(shape, dtype);

        let desc = DimOpIr {
            out: out.to_ir_out(),
            // `into_ir` consumes `tensor`, so all borrows of it happen above.
            input: tensor.into_ir(),
            axis: dim,
        };
        out.client.register(
            streams,
            OperationIr::BaseInt(BaseOperationIr::CumMax(desc.clone())),
            CummaxOps::<B>::new(desc),
        );

        out
    }
1413
1414 fn int_argmax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
1415 reduce_int_ops!(ArgMaxOps, B::int_argmax);
1416
1417 let dtype = tensor.dtype;
1418 let mut streams = OperationStreams::default();
1419 streams.tensor(&tensor);
1420 let mut shape = tensor.shape.clone();
1421 shape[dim] = 1;
1422 let out = tensor.client.tensor_uninitialized(shape, dtype);
1423
1424 let desc = ReduceDimOpIr {
1425 input: tensor.into_ir(),
1426 axis: dim,
1427 out: out.to_ir_out(),
1428 };
1429 out.client.register(
1430 streams,
1431 OperationIr::NumericInt(dtype, NumericOperationIr::ArgMax(desc.clone())),
1432 ArgMaxOps::<B>::new(desc),
1433 );
1434
1435 out
1436 }
1437
1438 fn int_argmin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
1439 reduce_int_ops!(ArgMinOps, B::int_argmin);
1440
1441 let dtype = tensor.dtype;
1442 let mut streams = OperationStreams::default();
1443 streams.tensor(&tensor);
1444 let mut shape = tensor.shape.clone();
1445 shape[dim] = 1;
1446 let out = tensor.client.tensor_uninitialized(shape, dtype);
1447
1448 let desc = ReduceDimOpIr {
1449 input: tensor.into_ir(),
1450 axis: dim,
1451 out: out.to_ir_out(),
1452 };
1453 out.client.register(
1454 streams,
1455 OperationIr::NumericInt(dtype, NumericOperationIr::ArgMin(desc.clone())),
1456 ArgMinOps::<B>::new(desc),
1457 );
1458
1459 out
1460 }
1461
    fn int_clamp(
        tensor: IntTensor<Self>,
        min: IntElem<Self>,
        max: IntElem<Self>,
    ) -> IntTensor<Self> {
        // Deferred element-wise clamp to the `[min, max]` bounds captured as
        // scalars in the IR; shape and dtype are unchanged.
        #[derive(new, Debug)]
        struct ClampOps<B: FusionBackend> {
            desc: ClampOpIr,
            _b: PhantomData<B>,
        }

        impl<B: FusionBackend> Operation<B::FusionRuntime> for ClampOps<B> {
            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
                let input = handles.get_int_tensor::<B>(&self.desc.tensor);
                let output = B::int_clamp(input, self.desc.min.elem(), self.desc.max.elem());

                handles.register_int_tensor::<B>(&self.desc.out.id, output);
            }
        }

        let dtype = tensor.dtype;
        let mut streams = OperationStreams::default();
        streams.tensor(&tensor);
        let out = tensor
            .client
            .tensor_uninitialized(tensor.shape.clone(), dtype);
        let desc = ClampOpIr {
            tensor: tensor.into_ir(),
            // Bounds are tagged with the tensor dtype so they round-trip exactly.
            min: ScalarIr::with_dtype(min, &dtype),
            max: ScalarIr::with_dtype(max, &dtype),
            out: out.to_ir_out(),
        };
        out.client.register(
            streams,
            OperationIr::NumericInt(dtype, NumericOperationIr::Clamp(desc.clone())),
            ClampOps::<B>::new(desc),
        );

        out
    }
1502
1503 fn int_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {
1504 unary_int_ops!(AbsOps, B::int_abs);
1505
1506 let dtype = tensor.dtype;
1507 let mut streams = OperationStreams::default();
1508 streams.tensor(&tensor);
1509 let out = tensor
1510 .client
1511 .tensor_uninitialized(tensor.shape.clone(), dtype);
1512
1513 let desc = UnaryOpIr {
1514 input: tensor.into_ir(),
1515 out: out.to_ir_out(),
1516 };
1517 out.client.register(
1518 streams,
1519 OperationIr::NumericInt(dtype, NumericOperationIr::Abs(desc.clone())),
1520 AbsOps::<B>::new(desc),
1521 );
1522
1523 out
1524 }
1525
    fn int_into_float(tensor: IntTensor<Self>) -> FloatTensor<Self> {
        // Deferred int -> float conversion; the result is registered as a float
        // tensor handle.
        #[derive(new, Debug)]
        struct IntoFloatOps<B: FusionBackend> {
            desc: UnaryOpIr,
            _b: PhantomData<B>,
        }

        impl<B: FusionBackend> Operation<B::FusionRuntime> for IntoFloatOps<B> {
            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
                let input = handles.get_int_tensor::<B>(&self.desc.input);
                let output = B::int_into_float(input);
                handles.register_float_tensor::<B>(&self.desc.out.id, output);
            }
        }

        let mut streams = OperationStreams::default();
        streams.tensor(&tensor);
        // Output dtype is the backend's default float element type.
        let out = tensor
            .client
            .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype());
        let desc = UnaryOpIr {
            input: tensor.into_ir(),
            out: out.to_ir_out(),
        };
        out.client.register(
            streams,
            OperationIr::Int(IntOperationIr::IntoFloat(desc.clone())),
            IntoFloatOps::<B>::new(desc),
        );

        out
    }
1558
    fn int_swap_dims(tensor: IntTensor<Self>, dim1: usize, dim2: usize) -> IntTensor<Self> {
        // Deferred swap of two dimensions; runs `B::int_swap_dims` at execution.
        #[derive(new, Debug)]
        struct SwapDimsOps<B: FusionBackend> {
            desc: SwapDimsOpIr,
            _b: PhantomData<B>,
        }

        impl<B: FusionBackend> Operation<B::FusionRuntime> for SwapDimsOps<B> {
            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
                let input = handles.get_int_tensor::<B>(&self.desc.input);
                let output = B::int_swap_dims(input, self.desc.dim1, self.desc.dim2);
                handles.register_int_tensor::<B>(&self.desc.out.id, output);
            }
        }

        let mut streams = OperationStreams::default();
        streams.tensor(&tensor);
        // `swap` returns a Result; unwrap panics on invalid dims (programmer error).
        let shape = tensor.shape.clone().swap(dim1, dim2).unwrap();
        let dtype = tensor.dtype;
        let out = tensor.client.tensor_uninitialized(shape, dtype);

        let desc = SwapDimsOpIr {
            input: tensor.into_ir(),
            dim1,
            dim2,
            out: out.to_ir_out(),
        };
        out.client.register(
            streams,
            OperationIr::BaseInt(BaseOperationIr::SwapDims(desc.clone())),
            SwapDimsOps::<B>::new(desc),
        );

        out
    }
1594
1595 fn int_max(tensor: IntTensor<Self>) -> IntTensor<Self> {
1596 unary_int_ops!(MaxOps, B::int_max, reduce);
1597
1598 let dtype = tensor.dtype;
1599 let mut streams = OperationStreams::default();
1600 streams.tensor(&tensor);
1601 let out = tensor.client.tensor_uninitialized(Shape::new([1]), dtype);
1602
1603 let desc = UnaryOpIr {
1604 input: tensor.into_ir(),
1605 out: out.to_ir_out(),
1606 };
1607 out.client.register(
1608 streams,
1609 OperationIr::NumericInt(dtype, NumericOperationIr::Max(desc.clone())),
1610 MaxOps::<B>::new(desc),
1611 );
1612
1613 out
1614 }
1615
1616 fn int_max_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
1617 reduce_int_ops!(MaxDimOps, B::int_max_dim);
1618
1619 let dtype = tensor.dtype;
1620 let mut streams = OperationStreams::default();
1621 streams.tensor(&tensor);
1622 let mut shape = tensor.shape.clone();
1623 shape[dim] = 1;
1624 let out = tensor.client.tensor_uninitialized(shape, dtype);
1625
1626 let desc = ReduceDimOpIr {
1627 input: tensor.into_ir(),
1628 axis: dim,
1629 out: out.to_ir_out(),
1630 };
1631 out.client.register(
1632 streams,
1633 OperationIr::NumericInt(dtype, NumericOperationIr::MaxDim(desc.clone())),
1634 MaxDimOps::<B>::new(desc),
1635 );
1636
1637 out
1638 }
1639
    fn int_max_dim_with_indices(
        tensor: IntTensor<Self>,
        dim: usize,
    ) -> (IntTensor<Self>, IntTensor<Self>) {
        // Deferred max-reduction along `dim` that also produces the argmax
        // indices; both outputs are registered in one execution.
        #[derive(new, Debug)]
        struct MaxDimWithIndicesOps<B: FusionBackend> {
            desc: ReduceDimWithIndicesOpIr,
            _b: PhantomData<B>,
        }

        impl<B: FusionBackend> Operation<B::FusionRuntime> for MaxDimWithIndicesOps<B> {
            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
                let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);
                let (output, indices) = B::int_max_dim_with_indices(tensor, self.desc.dim);

                handles.register_int_tensor::<B>(&self.desc.out.id, output);
                handles.register_int_tensor::<B>(&self.desc.out_indices.id, indices);
            }
        }

        let dtype = tensor.dtype;
        let mut streams = OperationStreams::default();
        streams.tensor(&tensor);
        let mut shape = tensor.shape.clone();
        shape[dim] = 1;
        let client = tensor.client.clone();
        let out = client.tensor_uninitialized(shape.clone(), dtype);
        // NOTE(review): indices reuse the input dtype — confirm the backend
        // returns indices in this dtype rather than a fixed index type.
        let out_indices = client.tensor_uninitialized(shape, dtype);
        let desc = ReduceDimWithIndicesOpIr {
            tensor: tensor.into_ir(),
            dim,
            out: out.to_ir_out(),
            out_indices: out_indices.to_ir_out(),
        };
        client.register(
            streams,
            OperationIr::NumericInt(dtype, NumericOperationIr::MaxDimWithIndices(desc.clone())),
            MaxDimWithIndicesOps::<B>::new(desc),
        );

        (out, out_indices)
    }
1682
1683 fn int_min(tensor: IntTensor<Self>) -> IntTensor<Self> {
1684 unary_int_ops!(MinOps, B::int_min, reduce);
1685
1686 let dtype = tensor.dtype;
1687 let mut streams = OperationStreams::default();
1688 streams.tensor(&tensor);
1689 let out = tensor.client.tensor_uninitialized(Shape::new([1]), dtype);
1690
1691 let desc = UnaryOpIr {
1692 input: tensor.into_ir(),
1693 out: out.to_ir_out(),
1694 };
1695 out.client.register(
1696 streams,
1697 OperationIr::NumericInt(dtype, NumericOperationIr::Min(desc.clone())),
1698 MinOps::<B>::new(desc),
1699 );
1700
1701 out
1702 }
1703
1704 fn int_max_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {
1705 unary_int_ops!(MaxAbsOps, B::int_max_abs, reduce);
1706
1707 let dtype = tensor.dtype;
1708 let mut streams = OperationStreams::default();
1709 streams.tensor(&tensor);
1710 let out = tensor.client.tensor_uninitialized(Shape::new([1]), dtype);
1711
1712 let desc = UnaryOpIr {
1713 input: tensor.into_ir(),
1714 out: out.to_ir_out(),
1715 };
1716 out.client.register(
1717 streams,
1718 OperationIr::NumericInt(dtype, NumericOperationIr::MaxAbs(desc.clone())),
1719 MaxAbsOps::<B>::new(desc),
1720 );
1721
1722 out
1723 }
1724
1725 fn int_max_abs_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
1726 reduce_int_ops!(MaxAbsDimOps, B::int_max_abs_dim);
1727
1728 let dtype = tensor.dtype;
1729 let mut streams = OperationStreams::default();
1730 streams.tensor(&tensor);
1731 let mut shape = tensor.shape.clone();
1732 shape[dim] = 1;
1733 let out = tensor.client.tensor_uninitialized(shape, dtype);
1734
1735 let desc = ReduceDimOpIr {
1736 input: tensor.into_ir(),
1737 axis: dim,
1738 out: out.to_ir_out(),
1739 };
1740 out.client.register(
1741 streams,
1742 OperationIr::NumericInt(dtype, NumericOperationIr::MaxAbsDim(desc.clone())),
1743 MaxAbsDimOps::<B>::new(desc),
1744 );
1745
1746 out
1747 }
1748
1749 fn int_min_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
1750 reduce_int_ops!(MinDimOps, B::int_min_dim);
1751
1752 let dtype = tensor.dtype;
1753 let mut streams = OperationStreams::default();
1754 streams.tensor(&tensor);
1755 let mut shape = tensor.shape.clone();
1756 shape[dim] = 1;
1757 let out = tensor.client.tensor_uninitialized(shape, dtype);
1758
1759 let desc = ReduceDimOpIr {
1760 input: tensor.into_ir(),
1761 axis: dim,
1762 out: out.to_ir_out(),
1763 };
1764
1765 out.client.register(
1766 streams,
1767 OperationIr::NumericInt(dtype, NumericOperationIr::MinDim(desc.clone())),
1768 MinDimOps::<B>::new(desc),
1769 );
1770
1771 out
1772 }
1773
    fn int_min_dim_with_indices(
        tensor: IntTensor<Self>,
        dim: usize,
    ) -> (IntTensor<Self>, IntTensor<Self>) {
        // Deferred min-reduction along `dim` that also produces the argmin
        // indices; both outputs are registered in one execution.
        #[derive(new, Debug)]
        struct MinDimWithIndicesOps<B: FusionBackend> {
            desc: ReduceDimWithIndicesOpIr,
            _b: PhantomData<B>,
        }

        impl<B: FusionBackend> Operation<B::FusionRuntime> for MinDimWithIndicesOps<B> {
            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
                let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);
                let (output, indices) = B::int_min_dim_with_indices(tensor, self.desc.dim);

                handles.register_int_tensor::<B>(&self.desc.out.id, output);
                handles.register_int_tensor::<B>(&self.desc.out_indices.id, indices);
            }
        }

        let dtype = tensor.dtype;
        let mut streams = OperationStreams::default();
        streams.tensor(&tensor);
        let mut shape = tensor.shape.clone();
        shape[dim] = 1;
        let client = tensor.client.clone();
        let out = client.tensor_uninitialized(shape.clone(), dtype);
        // NOTE(review): indices reuse the input dtype — confirm the backend
        // returns indices in this dtype rather than a fixed index type.
        let out_indices = client.tensor_uninitialized(shape, dtype);
        let desc = ReduceDimWithIndicesOpIr {
            tensor: tensor.into_ir(),
            dim,
            out: out.to_ir_out(),
            out_indices: out_indices.to_ir_out(),
        };
        client.register(
            streams,
            OperationIr::NumericInt(dtype, NumericOperationIr::MinDimWithIndices(desc.clone())),
            MinDimWithIndicesOps::<B>::new(desc),
        );

        (out, out_indices)
    }
1816
    fn int_random(
        shape: Shape,
        distribution: Distribution,
        device: &Device<Self>,
    ) -> IntTensor<Self> {
        // Deferred random int tensor; the distribution is captured in the IR and
        // sampled on the captured device at execution time.
        #[derive(new, Debug)]
        struct IntRandomOps<B: FusionBackend> {
            desc: RandomOpIr,
            device: Device<B>,
        }

        impl<B: FusionBackend> Operation<B::FusionRuntime> for IntRandomOps<B> {
            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
                let shape = self.desc.out.shape.clone();
                let output = B::int_random(shape, self.desc.distribution, &self.device);
                handles.register_int_tensor::<B>(&self.desc.out.id, output);
            }
        }

        let client = get_client::<B>(&device.clone());
        // Unlike the other creation ops, dtype is fixed to the backend's default
        // int element type (no dtype parameter).
        let out = client.tensor_uninitialized(shape, B::IntElem::dtype());

        let desc = RandomOpIr {
            out: out.to_ir_out(),
            distribution,
        };
        client.register(
            // Creation op: no input tensors, hence no streams to synchronize with.
            OperationStreams::default(),
            OperationIr::NumericInt(
                IntElem::<Self>::dtype(),
                NumericOperationIr::IntRandom(desc.clone()),
            ),
            IntRandomOps::<B>::new(desc, device.clone()),
        );

        out
    }
1854
    fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
        // Deferred axis permutation; `axes` is copied into the IR as an owned Vec.
        #[derive(new, Debug)]
        struct PermuteDimsOps<B: FusionBackend> {
            desc: PermuteOpIr,
            _b: PhantomData<B>,
        }

        impl<B: FusionBackend> Operation<B::FusionRuntime> for PermuteDimsOps<B> {
            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
                let input = handles.get_int_tensor::<B>(&self.desc.input);
                let output = B::int_permute(input, self.desc.axes.as_slice());
                handles.register_int_tensor::<B>(&self.desc.out.id, output);
            }
        }

        let mut streams = OperationStreams::default();
        streams.tensor(&tensor);

        // `permute` returns a Result; unwrap panics on an invalid axes list.
        let shape = tensor.shape.clone().permute(axes).unwrap();
        let dtype = tensor.dtype;
        let out = tensor.client.tensor_uninitialized(shape, dtype);

        let desc = PermuteOpIr {
            input: tensor.into_ir(),
            axes: axes.to_vec(),
            out: out.to_ir_out(),
        };

        out.client.register(
            streams,
            OperationIr::BaseInt(BaseOperationIr::Permute(desc.clone())),
            PermuteDimsOps::<B>::new(desc),
        );

        out
    }
1892
    fn int_expand(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
        // Deferred broadcast/expand to `shape`; the target shape is stored in the IR.
        #[derive(new, Debug)]
        struct ExpandOps<B: FusionBackend> {
            desc: ExpandOpIr,
            _b: PhantomData<B>,
        }

        impl<B: FusionBackend> Operation<B::FusionRuntime> for ExpandOps<B> {
            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
                let input = handles.get_int_tensor::<B>(&self.desc.input);
                let output = B::int_expand(input, self.desc.shape.clone());
                handles.register_int_tensor::<B>(&self.desc.out.id, output);
            }
        }

        let mut streams = OperationStreams::default();
        streams.tensor(&tensor);

        let dtype = tensor.dtype;
        let out = tensor.client.tensor_uninitialized(shape.clone(), dtype);

        let desc = ExpandOpIr {
            input: tensor.into_ir(),
            shape,
            out: out.to_ir_out(),
        };

        out.client.register(
            streams,
            OperationIr::BaseInt(BaseOperationIr::Expand(desc.clone())),
            ExpandOps::<B>::new(desc),
        );

        out
    }
1928
    fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
        // Deferred flip along the given axes; output shape equals the input shape.
        #[derive(new, Debug)]
        struct FlipDimsOps<B: FusionBackend> {
            desc: FlipOpIr,
            _b: PhantomData<B>,
        }

        impl<B: FusionBackend> Operation<B::FusionRuntime> for FlipDimsOps<B> {
            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
                let input = handles.get_int_tensor::<B>(&self.desc.input);
                let axes = &self.desc.axes;
                let output = B::int_flip(input, axes);
                handles.register_int_tensor::<B>(&self.desc.out.id, output);
            }
        }

        let mut streams = OperationStreams::default();
        streams.tensor(&tensor);

        let dtype = tensor.dtype;
        let out = tensor
            .client
            .tensor_uninitialized(tensor.shape.clone(), dtype);

        let desc = FlipOpIr {
            input: tensor.into_ir(),
            axes: axes.to_vec(),
            out: out.to_ir_out(),
        };

        out.client.register(
            streams,
            OperationIr::BaseInt(BaseOperationIr::Flip(desc.clone())),
            FlipDimsOps::<B>::new(desc),
        );

        out
    }
1967
    fn int_repeat_dim(tensor: IntTensor<Self>, dim: usize, times: usize) -> IntTensor<Self> {
        // Deferred repeat of dimension `dim` `times` times.
        #[derive(new, Debug)]
        struct RepeatDimOps<B: FusionBackend> {
            desc: RepeatDimOpIr,
            _b: PhantomData<B>,
        }

        impl<B: FusionBackend> Operation<B::FusionRuntime> for RepeatDimOps<B> {
            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
                let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);

                let output = B::int_repeat_dim(tensor, self.desc.dim, self.desc.times);

                handles.register_int_tensor::<B>(&self.desc.out.id, output);
            }
        }

        let dtype = tensor.dtype;
        let mut streams = OperationStreams::default();
        streams.tensor(&tensor);
        // Output shape: `dim` is scaled by `times`, other dims unchanged.
        let shape = tensor.shape.clone().repeat(dim, times);
        let out = tensor.client.tensor_uninitialized(shape, dtype);

        let desc = RepeatDimOpIr {
            tensor: tensor.into_ir(),
            dim,
            times,
            out: out.to_ir_out(),
        };
        out.client.register(
            streams,
            OperationIr::BaseInt(BaseOperationIr::RepeatDim(desc.clone())),
            RepeatDimOps::<B>::new(desc),
        );

        out
    }
2005
2006 fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
2007 binary_int_ops!(BitwiseAndOps, B::bitwise_and);
2008
2009 let dtype = lhs.dtype;
2010 let mut streams = OperationStreams::default();
2011 streams.tensor(&lhs);
2012 streams.tensor(&rhs);
2013 let out = lhs
2014 .client
2015 .tensor_uninitialized(lhs.shape.broadcast(&rhs.shape).unwrap(), dtype);
2016
2017 let desc = BinaryOpIr {
2018 lhs: lhs.into_ir(),
2019 rhs: rhs.into_ir(),
2020 out: out.to_ir_out(),
2021 };
2022 out.client.register(
2023 streams,
2024 OperationIr::Int(IntOperationIr::BitwiseAnd(desc.clone())),
2025 BitwiseAndOps::<B>::new(desc),
2026 );
2027
2028 out
2029 }
2030
2031 fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
2032 scalar_int_ops!(BitwiseAndOps, B::bitwise_and_scalar);
2033
2034 let dtype = lhs.dtype;
2035 let mut streams = OperationStreams::default();
2036 streams.tensor(&lhs);
2037 let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype);
2038
2039 let desc = ScalarOpIr {
2040 lhs: lhs.into_ir(),
2041 rhs: ScalarIr::with_dtype(rhs, &dtype),
2042 out: out.to_ir_out(),
2043 };
2044 out.client.register(
2045 streams,
2046 OperationIr::Int(IntOperationIr::BitwiseAndScalar(desc.clone())),
2047 BitwiseAndOps::<B>::new(desc),
2048 );
2049
2050 out
2051 }
2052
2053 fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
2054 binary_int_ops!(BitwiseOrOps, B::bitwise_or);
2055
2056 let dtype = lhs.dtype;
2057 let mut streams = OperationStreams::default();
2058 streams.tensor(&lhs);
2059 streams.tensor(&rhs);
2060 let out = lhs
2061 .client
2062 .tensor_uninitialized(lhs.shape.broadcast(&rhs.shape).unwrap(), dtype);
2063
2064 let desc = BinaryOpIr {
2065 lhs: lhs.into_ir(),
2066 rhs: rhs.into_ir(),
2067 out: out.to_ir_out(),
2068 };
2069 out.client.register(
2070 streams,
2071 OperationIr::Int(IntOperationIr::BitwiseOr(desc.clone())),
2072 BitwiseOrOps::<B>::new(desc),
2073 );
2074
2075 out
2076 }
2077
2078 fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
2079 scalar_int_ops!(BitwiseOrOps, B::bitwise_or_scalar);
2080
2081 let dtype = lhs.dtype;
2082 let mut streams = OperationStreams::default();
2083 streams.tensor(&lhs);
2084 let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype);
2085
2086 let desc = ScalarOpIr {
2087 lhs: lhs.into_ir(),
2088 rhs: ScalarIr::with_dtype(rhs, &dtype),
2089 out: out.to_ir_out(),
2090 };
2091 out.client.register(
2092 streams,
2093 OperationIr::Int(IntOperationIr::BitwiseOrScalar(desc.clone())),
2094 BitwiseOrOps::<B>::new(desc),
2095 );
2096
2097 out
2098 }
2099
2100 fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
2101 binary_int_ops!(BitwiseXorOps, B::bitwise_xor);
2102
2103 let dtype = lhs.dtype;
2104 let mut streams = OperationStreams::default();
2105 streams.tensor(&lhs);
2106 streams.tensor(&rhs);
2107 let out = lhs
2108 .client
2109 .tensor_uninitialized(lhs.shape.broadcast(&rhs.shape).unwrap(), dtype);
2110
2111 let desc = BinaryOpIr {
2112 lhs: lhs.into_ir(),
2113 rhs: rhs.into_ir(),
2114 out: out.to_ir_out(),
2115 };
2116 out.client.register(
2117 streams,
2118 OperationIr::Int(IntOperationIr::BitwiseXor(desc.clone())),
2119 BitwiseXorOps::<B>::new(desc),
2120 );
2121
2122 out
2123 }
2124
2125 fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
2126 scalar_int_ops!(BitwiseXorOps, B::bitwise_xor_scalar);
2127
2128 let dtype = lhs.dtype;
2129 let mut streams = OperationStreams::default();
2130 streams.tensor(&lhs);
2131 let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype);
2132
2133 let desc = ScalarOpIr {
2134 lhs: lhs.into_ir(),
2135 rhs: ScalarIr::with_dtype(rhs, &dtype),
2136 out: out.to_ir_out(),
2137 };
2138 out.client.register(
2139 streams,
2140 OperationIr::Int(IntOperationIr::BitwiseXorScalar(desc.clone())),
2141 BitwiseXorOps::<B>::new(desc),
2142 );
2143
2144 out
2145 }
2146
2147 fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {
2148 unary_int_ops!(BitwiseNotOps, B::bitwise_not);
2149
2150 let dtype = tensor.dtype;
2151 let mut streams = OperationStreams::default();
2152 streams.tensor(&tensor);
2153 let out = tensor
2154 .client
2155 .tensor_uninitialized(tensor.shape.clone(), dtype);
2156
2157 let desc = UnaryOpIr {
2158 input: tensor.into_ir(),
2159 out: out.to_ir_out(),
2160 };
2161 out.client.register(
2162 streams,
2163 OperationIr::Int(IntOperationIr::BitwiseNot(desc.clone())),
2164 BitwiseNotOps::<B>::new(desc),
2165 );
2166
2167 out
2168 }
2169
2170 fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
2171 binary_int_ops!(BitwiseLeftShiftOps, B::bitwise_left_shift);
2172
2173 let dtype = lhs.dtype;
2174 let mut streams = OperationStreams::default();
2175 streams.tensor(&lhs);
2176 streams.tensor(&rhs);
2177 let out = lhs
2178 .client
2179 .tensor_uninitialized(lhs.shape.broadcast(&rhs.shape).unwrap(), dtype);
2180
2181 let desc = BinaryOpIr {
2182 lhs: lhs.into_ir(),
2183 rhs: rhs.into_ir(),
2184 out: out.to_ir_out(),
2185 };
2186 out.client.register(
2187 streams,
2188 OperationIr::Int(IntOperationIr::BitwiseLeftShift(desc.clone())),
2189 BitwiseLeftShiftOps::<B>::new(desc),
2190 );
2191
2192 out
2193 }
2194
2195 fn bitwise_left_shift_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
2196 scalar_int_ops!(BitwiseLeftShiftOps, B::bitwise_left_shift_scalar);
2197
2198 let dtype = lhs.dtype;
2199 let mut streams = OperationStreams::default();
2200 streams.tensor(&lhs);
2201 let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype);
2202
2203 let desc = ScalarOpIr {
2204 lhs: lhs.into_ir(),
2205 rhs: ScalarIr::with_dtype(rhs, &dtype),
2206 out: out.to_ir_out(),
2207 };
2208 out.client.register(
2209 streams,
2210 OperationIr::Int(IntOperationIr::BitwiseLeftShiftScalar(desc.clone())),
2211 BitwiseLeftShiftOps::<B>::new(desc),
2212 );
2213
2214 out
2215 }
2216
2217 fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
2218 binary_int_ops!(BitwiseRightShiftOps, B::bitwise_right_shift);
2219
2220 let dtype = lhs.dtype;
2221 let mut streams = OperationStreams::default();
2222 streams.tensor(&lhs);
2223 streams.tensor(&rhs);
2224 let out = lhs
2225 .client
2226 .tensor_uninitialized(lhs.shape.broadcast(&rhs.shape).unwrap(), dtype);
2227
2228 let desc = BinaryOpIr {
2229 lhs: lhs.into_ir(),
2230 rhs: rhs.into_ir(),
2231 out: out.to_ir_out(),
2232 };
2233 out.client.register(
2234 streams,
2235 OperationIr::Int(IntOperationIr::BitwiseRightShift(desc.clone())),
2236 BitwiseRightShiftOps::<B>::new(desc),
2237 );
2238
2239 out
2240 }
2241
2242 fn bitwise_right_shift_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
2243 scalar_int_ops!(BitwiseRightShiftOps, B::bitwise_right_shift_scalar);
2244
2245 let dtype = lhs.dtype;
2246 let mut streams = OperationStreams::default();
2247 streams.tensor(&lhs);
2248 let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype);
2249
2250 let desc = ScalarOpIr {
2251 lhs: lhs.into_ir(),
2252 rhs: ScalarIr::with_dtype(rhs, &dtype),
2253 out: out.to_ir_out(),
2254 };
2255 out.client.register(
2256 streams,
2257 OperationIr::Int(IntOperationIr::BitwiseRightShiftScalar(desc.clone())),
2258 BitwiseRightShiftOps::<B>::new(desc),
2259 );
2260
2261 out
2262 }
2263
2264 fn int_cast(tensor: IntTensor<Self>, dtype: burn_tensor::IntDType) -> IntTensor<Self> {
2265 #[derive(new, Debug)]
2266 struct CastOps<B: FusionBackend> {
2267 desc: UnaryOpIr,
2268 dtype: burn_tensor::IntDType,
2269 _b: PhantomData<B>,
2270 }
2271
2272 impl<B: FusionBackend> Operation<B::FusionRuntime> for CastOps<B> {
2273 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2274 let input = handles.get_int_tensor::<B>(&self.desc.input);
2275 let output: B::IntTensorPrimitive = B::int_cast(input, self.dtype);
2276 handles.register_int_tensor::<B>(&self.desc.out.id, output);
2277 }
2278 }
2279
2280 let mut streams = OperationStreams::default();
2281 streams.tensor(&tensor);
2282 let out = tensor
2283 .client
2284 .tensor_uninitialized(tensor.shape.clone(), dtype.into());
2285
2286 let desc = UnaryOpIr {
2287 input: tensor.into_ir(),
2288 out: out.to_ir_out(),
2289 };
2290 out.client.register(
2291 streams,
2292 OperationIr::BaseInt(BaseOperationIr::Cast(desc.clone())),
2293 CastOps::<B>::new(desc, dtype),
2294 );
2295
2296 out
2297 }
2298
2299 fn int_unfold(
2300 tensor: IntTensor<Self>,
2301 dim: usize,
2302 size: usize,
2303 step: usize,
2304 ) -> IntTensor<Self> {
2305 #[derive(new, Debug)]
2306 struct UnfoldOps<B: FusionBackend> {
2307 desc: UnfoldOpIr,
2308 _b: PhantomData<B>,
2309 }
2310
2311 impl<B: FusionBackend> Operation<B::FusionRuntime> for UnfoldOps<B> {
2312 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2313 let input = handles.get_int_tensor::<B>(&self.desc.input);
2314 let output = B::int_unfold(input, self.desc.dim, self.desc.size, self.desc.step);
2315
2316 handles.register_int_tensor::<B>(&self.desc.out.id, output);
2317 }
2318 }
2319
2320 let mut streams = OperationStreams::default();
2321 streams.tensor(&tensor);
2322
2323 let shape = calculate_unfold_shape(tensor.shape(), dim, size, step);
2324 let out = tensor
2325 .client
2326 .tensor_uninitialized(Shape::from(shape), tensor.dtype);
2327
2328 let desc = UnfoldOpIr {
2329 input: tensor.into_ir(),
2330 out: out.to_ir_out(),
2331 dim,
2332 size,
2333 step,
2334 };
2335
2336 out.client.register(
2337 streams,
2338 OperationIr::BaseInt(BaseOperationIr::Unfold(desc.clone())),
2339 UnfoldOps::<B>::new(desc),
2340 );
2341
2342 out
2343 }
2344}