1#![allow(missing_docs)]
2
3use alloc::vec::Vec;
4use burn_backend::{
5 DType, Distribution, Shape, Slice, calculate_matmul_output,
6 ops::{
7 conv::{
8 calculate_conv_output_shape, calculate_conv_transpose_output_shape,
9 calculate_pool_output_shape,
10 },
11 unfold::calculate_unfold_shape,
12 },
13 quantization::QuantScheme,
14 tensor::IndexingUpdateOp,
15};
16
17use crate::{ScalarIr, TensorId, TensorIr};
18
19use super::operation::*;
20
21impl CreationOpIr {
22 pub fn create(shape: Shape, dtype: DType, new_id: impl FnOnce() -> TensorId) -> Self {
23 let out = TensorIr::uninit(new_id(), shape, dtype);
24
25 CreationOpIr { out }
26 }
27}
28
29impl InitOperationIr {
30 pub fn create(shape: Shape, dtype: DType, new_id: impl FnOnce() -> TensorId) -> Self {
31 let out = TensorIr::uninit(new_id(), shape, dtype);
32
33 InitOperationIr { out }
34 }
35}
36
37impl RandomOpIr {
38 pub fn create(
39 shape: Shape,
40 dtype: DType,
41 distribution: Distribution,
42 new_id: impl FnOnce() -> TensorId,
43 ) -> Self {
44 let out = TensorIr::uninit(new_id(), shape, dtype);
45
46 RandomOpIr { out, distribution }
47 }
48}
49
50impl FullOpIr {
51 pub fn create(
52 shape: Shape,
53 dtype: DType,
54 value: ScalarIr,
55 new_id: impl FnOnce() -> TensorId,
56 ) -> Self {
57 let out = TensorIr::uninit(new_id(), shape, dtype);
58
59 FullOpIr { out, value }
60 }
61}
62
63impl CastOpIr {
64 pub fn create(input: TensorIr, dtype: DType, new_id: impl FnOnce() -> TensorId) -> Self {
65 let out = TensorIr::uninit(new_id(), input.shape.clone(), dtype);
66 CastOpIr { input, out }
67 }
68}
69
70impl ShapeOpIr {
71 pub fn expand(input: TensorIr, shape: Shape, new_id: impl FnOnce() -> TensorId) -> Self {
72 let shape = input.shape.expand(shape).unwrap();
73 Self::create(input, shape, new_id)
74 }
75
76 pub fn reshape(input: TensorIr, shape: Shape, new_id: impl FnOnce() -> TensorId) -> Self {
77 let shape = input.shape.reshape(shape).unwrap();
78 Self::create(input, shape, new_id)
79 }
80
81 fn create(input: TensorIr, shape: Shape, new_id: impl FnOnce() -> TensorId) -> Self {
82 let out = TensorIr::uninit(new_id(), shape, input.dtype);
83 ShapeOpIr { input, out }
84 }
85}
86
87impl From<MatmulOpIr> for BinaryOpIr {
90 fn from(value: MatmulOpIr) -> Self {
91 Self {
92 lhs: value.lhs,
93 rhs: value.rhs,
94 out: value.out,
95 }
96 }
97}
98
99impl From<ReduceOpIr> for UnaryOpIr {
100 fn from(value: ReduceOpIr) -> Self {
101 Self {
102 input: value.input,
103 out: value.out,
104 }
105 }
106}
107
/// Errors produced while validating the inputs of an operation IR.
#[derive(Debug)]
#[allow(missing_docs)]
pub enum IrError {
    /// The dtypes of the operation's inputs are incompatible with each other.
    DTypeMismatch,
}
113
114fn dtype_compat(lhs: &DType, rhs: &DType) -> bool {
115 let lhs_qfloat = matches!(lhs, DType::QFloat(_));
116 let rhs_qfloat = matches!(rhs, DType::QFloat(_));
117 if lhs_qfloat && (rhs_qfloat || rhs.is_float())
118 || lhs.is_float() && (rhs_qfloat || rhs.is_float())
119 {
120 true
121 } else {
122 lhs == rhs
123 }
124}
125
126fn output_check<'a, I>(inputs: I, compat: impl Fn(&DType, &DType) -> bool) -> Result<DType, IrError>
127where
128 I: IntoIterator<Item = &'a DType>,
129{
130 let mut iter = inputs.into_iter();
131 let first = iter.next().unwrap();
132 for d in iter {
133 if !compat(first, d) {
134 return Err(IrError::DTypeMismatch);
135 }
136 }
137 Ok(*first)
138}
139
/// Returns the shared dtype of `inputs`, requiring exact dtype equality.
fn output_dtype<'a, I: IntoIterator<Item = &'a DType>>(inputs: I) -> Result<DType, IrError> {
    output_check(inputs, |a, b| a == b)
}

/// Returns the output dtype of `inputs`, allowing float / quantized-float
/// dtypes to mix (see `dtype_compat`).
fn output_dtype_mixed<'a, I: IntoIterator<Item = &'a DType>>(inputs: I) -> Result<DType, IrError> {
    output_check(inputs, dtype_compat)
}
147
// Generates `create` constructors for operation-IR structs.
//
// Usage:
//     impl_ir_create!(Op { field: Ty, ... }, shape = <expr>, dtype = <expr>);
// optionally followed by `fn_name(extra: Ty)` to additionally emit a second
// constructor that overrides the output dtype with `extra` (used for
// comparison ops producing bool outputs, argmax-style index outputs, and
// mixed-precision matmul).
//
// Inside the `shape = ...` / `dtype = ...` expressions the declared fields
// are in scope by name, so expressions like `input.shape.clone()` work.
macro_rules! impl_ir_create {
    // Internal rule: emits the standard `create` constructor whose output
    // dtype comes from the `dtype = ...` expression.
    (@create_fn $op:ident { $( $field:ident : $ty:ty ),* $(,)? } , $shape:expr, $dtype:expr) => {
        #[doc = "Create a new operation IR from the given inputs."]
        #[doc = "`new_id` should generate a unique `TensorId` for the uninitialized output tensor."]
        #[allow(clippy::too_many_arguments)]
        pub fn create($( $field : $ty ),*, new_id: impl FnOnce() -> crate::TensorId) -> $op {
            let shape = $shape;
            let dtype = $dtype;
            let out = TensorIr::uninit(new_id(), shape, dtype);
            $op { $( $field ),*, out }
        }
    };

    // Variant without a dtype-override constructor.
    (
        $op:ident { $( $field:ident : $ty:ty ),* $(,)? },
        shape = $shape:expr,
        dtype = $dtype:expr
    ) => {
        impl $op {
            impl_ir_create!(@create_fn $op { $( $field : $ty ),* }, $shape, $dtype);
        }
    };

    // Variant that also emits `$fn_name`, a constructor taking an explicit
    // output dtype (`$extra`) instead of the computed one.
    (
        $op:ident { $( $field:ident : $ty:ty ),* $(,)? },
        shape = $shape:expr,
        dtype = $dtype:expr,
        $fn_name:ident ( $extra:ident : $extra_ty:ty )
    ) => {
        impl $op {
            impl_ir_create!(@create_fn $op { $( $field : $ty ),* }, $shape, $dtype);

            #[doc = "Create a new operation IR from the given inputs and the given output dtype."]
            #[allow(clippy::too_many_arguments)]
            pub fn $fn_name($( $field : $ty ),*, $extra: $extra_ty, new_id: impl FnOnce() -> crate::TensorId) -> Self {
                let shape = $shape;
                // `$dtype` is still evaluated for its side effects (the
                // dtype-compatibility `unwrap`s) before being discarded in
                // favor of the explicit `$extra` dtype.
                let _ = $dtype; let out = TensorIr::uninit(new_id(), shape, $extra);
                $op { $( $field ),*, out }
            }
        }
    };
}
196
// Elementwise unary op: output matches the input's shape and dtype.
// `create_comparison` overrides the output dtype (for bool-valued results).
impl_ir_create!(
    UnaryOpIr { input: TensorIr },
    shape = input.shape.clone(),
    dtype = input.dtype,
    create_comparison(bool_dtype: DType)
);

// Elementwise binary op: operand shapes broadcast; dtypes must match exactly.
impl_ir_create!(
    BinaryOpIr {
        lhs: TensorIr,
        rhs: TensorIr
    },
    shape = lhs.shape.broadcast(&rhs.shape).unwrap(),
    dtype = output_dtype([&lhs.dtype, &rhs.dtype]).unwrap(),
    create_comparison(bool_dtype: DType)
);

// Tensor-scalar op: output mirrors the tensor operand.
impl_ir_create!(
    ScalarOpIr {
        lhs: TensorIr,
        rhs: ScalarIr
    },
    shape = lhs.shape.clone(),
    dtype = lhs.dtype,
    create_comparison(bool_dtype: DType)
);

// Matrix multiplication: output shape from `calculate_matmul_output`; float
// and quantized-float dtypes may mix. `create_mixed` sets an explicit
// output dtype.
impl_ir_create!(
    MatmulOpIr {
        lhs: TensorIr,
        rhs: TensorIr
    },
    shape = calculate_matmul_output(&lhs.shape, &rhs.shape).unwrap(),
    dtype = output_dtype_mixed([&lhs.dtype, &rhs.dtype]).unwrap(),
    create_mixed(out_dtype: DType)
);
237
// Swap two dimensions; dtype unchanged.
impl_ir_create!(
    SwapDimsOpIr {
        input: TensorIr,
        dim1: usize,
        dim2: usize
    },
    shape = input.shape.clone().swap(dim1, dim2).unwrap(),
    dtype = input.dtype
);

// Permute the dimensions according to `axes`.
impl_ir_create!(
    PermuteOpIr { input: TensorIr, axes: Vec<usize> },
    shape = input.shape.clone().permute(&axes).unwrap(),
    dtype = input.dtype
);

// Repeat a single dimension `times` times.
impl_ir_create!(
    RepeatDimOpIr {
        tensor: TensorIr,
        dim: usize,
        times: usize
    },
    shape = tensor.shape.clone().repeat(dim, times).unwrap(),
    dtype = tensor.dtype
);

// Flip along the given axes; shape is unchanged.
impl_ir_create!(
    FlipOpIr { input: TensorIr, axes: Vec<usize> },
    shape = input.shape.clone(), dtype = input.dtype
);

// Concatenate along `dim`; all tensors must share a dtype.
impl_ir_create!(
    CatOpIr { tensors: Vec<TensorIr>, dim: usize },
    shape = Shape::cat(tensors.iter().map(|t| &t.shape), dim).unwrap(),
    dtype = output_dtype(tensors.iter().map(|t| &t.dtype)).unwrap()
);
275
// Gather: output takes the indices' shape and the source tensor's dtype.
impl_ir_create!(
    GatherOpIr {
        tensor: TensorIr,
        dim: usize,
        indices: TensorIr
    },
    shape = indices.shape.clone(), dtype = tensor.dtype
);

// Scatter: writes `value` into `tensor` at `indices`; output mirrors `tensor`.
impl_ir_create!(
    ScatterOpIr {
        tensor: TensorIr,
        dim: usize,
        indices: TensorIr,
        value: TensorIr,
        update: IndexingUpdateOp
    },
    shape = tensor.shape.clone(), dtype = output_dtype([&tensor.dtype, &value.dtype]).unwrap()
);

// Full reduction: everything collapses to a single-element tensor.
impl_ir_create!(
    ReduceOpIr { input: TensorIr },
    shape = [1].into(),
    dtype = input.dtype
);

// Reduction along one axis. `create_arg` overrides the output dtype for
// argmax/argmin-style index outputs.
impl_ir_create!(
    ReduceDimOpIr {
        input: TensorIr,
        axis: usize
    },
    shape = input.shape.clone().reduce(axis).unwrap(),
    dtype = input.dtype,
    create_arg(ind_dtype: DType)
);

// Per-axis op whose output keeps the input's shape and dtype.
impl_ir_create!(
    DimOpIr {
        input: TensorIr,
        axis: usize
    },
    shape = input.shape.clone(), dtype = input.dtype
);

// Select along `dim`: that dimension's size becomes the number of indices.
impl_ir_create!(
    SelectOpIr {
        tensor: TensorIr,
        dim: usize,
        indices: TensorIr
    },
    shape = {
        let mut s = tensor.shape.clone();
        s[dim] = indices.shape[0];
        s
    },
    dtype = tensor.dtype
);

// Select-assign: writes `value` at the selected indices; output mirrors
// `tensor`.
impl_ir_create!(
    SelectAssignOpIr {
        tensor: TensorIr,
        dim: usize,
        indices: TensorIr,
        value: TensorIr,
        update: IndexingUpdateOp
    },
    shape = tensor.shape.clone(),
    dtype = output_dtype([&tensor.dtype, &value.dtype]).unwrap()
);
351
// Slice: output shape is derived from the slice ranges.
impl_ir_create!(
    SliceOpIr {
        tensor: TensorIr,
        ranges: Vec<Slice>,
    },
    shape = tensor.shape.clone().slice(&ranges).unwrap(),
    dtype = tensor.dtype
);

// Slice-assign: writes `value` into the sliced region; output mirrors
// `tensor`.
impl_ir_create!(
    SliceAssignOpIr {
        tensor: TensorIr,
        ranges: Vec<Slice>,
        value: TensorIr
    },
    shape = tensor.shape.clone(),
    dtype = output_dtype([&tensor.dtype, &value.dtype]).unwrap()
);

// Mask-where: elementwise select between `tensor` and `value` driven by
// `mask`; all three shapes broadcast together.
impl_ir_create!(
    MaskWhereOpIr {
        tensor: TensorIr,
        mask: TensorIr,
        value: TensorIr
    },
    shape = Shape::broadcast_many([&tensor.shape, &mask.shape, &value.shape]).unwrap(),
    dtype = output_dtype([&tensor.dtype, &value.dtype]).unwrap()
);

// Mask-fill: fill masked positions with a scalar; tensor and mask broadcast.
impl_ir_create!(
    MaskFillOpIr {
        tensor: TensorIr,
        mask: TensorIr,
        value: ScalarIr
    },
    shape = tensor.shape.broadcast(&mask.shape).unwrap(),
    dtype = tensor.dtype
);

// Clamp each element into the scalar `[min, max]` range.
impl_ir_create!(
    ClampOpIr {
        tensor: TensorIr,
        min: ScalarIr,
        max: ScalarIr
    },
    shape = tensor.shape.clone(),
    dtype = tensor.dtype
);
401
// 1d average pooling; dilation is fixed to `[1]` (avg pool has none).
impl_ir_create!(
    AvgPool1dOpIr {
        x: TensorIr,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
        ceil_mode: bool
    },
    shape = calculate_pool_output_shape(
        &x.shape,
        &[kernel_size],
        &[stride],
        &[padding],
        &[1],
        ceil_mode
    )
    .unwrap(),
    dtype = x.dtype
);

// Backward of 1d average pooling: the gradient has the input's shape.
impl_ir_create!(
    AvgPool1dBackwardOpIr {
        x: TensorIr,
        grad: TensorIr,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
        ceil_mode: bool
    },
    shape = x.shape.clone(),
    dtype = x.dtype
);

// 2d average pooling; dilation fixed to `[1, 1]`.
impl_ir_create!(
    AvgPool2dOpIr {
        x: TensorIr,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool
    },
    shape = calculate_pool_output_shape(
        &x.shape,
        &kernel_size,
        &stride,
        &padding,
        &[1, 1],
        ceil_mode
    )
    .unwrap(),
    dtype = x.dtype
);

// Backward of 2d average pooling.
impl_ir_create!(
    AvgPool2dBackwardOpIr {
        x: TensorIr,
        grad: TensorIr,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool
    },
    shape = x.shape.clone(),
    dtype = x.dtype
);

// 1d max pooling (values only; see `MaxPool1dWithIndicesOpIr` for indices).
impl_ir_create!(
    MaxPool1dOpIr {
        x: TensorIr,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool
    },
    shape = calculate_pool_output_shape(
        &x.shape,
        &[kernel_size],
        &[stride],
        &[padding],
        &[dilation],
        ceil_mode
    )
    .unwrap(),
    dtype = x.dtype
);

// 2d max pooling (values only).
impl_ir_create!(
    MaxPool2dOpIr {
        x: TensorIr,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool
    },
    shape = calculate_pool_output_shape(
        &x.shape,
        &kernel_size,
        &stride,
        &padding,
        &dilation,
        ceil_mode
    )
    .unwrap(),
    dtype = x.dtype
);

// Backward of 1d max pooling with indices: gradient has the input's shape.
impl_ir_create!(
    MaxPool1dWithIndicesBackwardOpIr {
        x: TensorIr,
        grad: TensorIr,
        indices: TensorIr,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool
    },
    shape = x.shape.clone(),
    dtype = x.dtype
);

// Backward of 2d max pooling with indices.
impl_ir_create!(
    MaxPool2dWithIndicesBackwardOpIr {
        x: TensorIr,
        grad: TensorIr,
        indices: TensorIr,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool
    },
    shape = x.shape.clone(),
    dtype = x.dtype
);
543
// Adaptive 1d average pooling: keeps the two leading dims of `x`, last dim
// becomes `output_size`.
impl_ir_create!(
    AdaptiveAvgPool1dOpIr {
        x: TensorIr,
        output_size: usize
    },
    shape = Shape::new([x.shape[0], x.shape[1], output_size]),
    dtype = x.dtype
);

// Adaptive 2d average pooling: keeps the two leading dims, spatial dims
// become `output_size`.
impl_ir_create!(
    AdaptiveAvgPool2dOpIr {
        x: TensorIr,
        output_size: [usize; 2]
    },
    shape = Shape::new([x.shape[0], x.shape[1], output_size[0], output_size[1]]),
    dtype = x.dtype
);

// Backward of adaptive 1d average pooling.
impl_ir_create!(
    AdaptiveAvgPool1dBackwardOpIr {
        x: TensorIr,
        grad: TensorIr,
    },
    shape = x.shape.clone(),
    dtype = x.dtype
);

// Backward of adaptive 2d average pooling.
impl_ir_create!(
    AdaptiveAvgPool2dBackwardOpIr {
        x: TensorIr,
        grad: TensorIr,
    },
    shape = x.shape.clone(),
    dtype = x.dtype
);

// Interpolation (resize) to `output_size`.
impl_ir_create!(
    InterpolateOpIr {
        x: TensorIr,
        output_size: [usize; 2],
        options: InterpolateOptionsIr
    },
    shape = Shape::new([x.shape[0], x.shape[1], output_size[0], output_size[1]]),
    dtype = x.dtype
);

// Backward of interpolation: gradient has the input's shape.
impl_ir_create!(
    InterpolateBackwardOpIr {
        x: TensorIr,
        grad: TensorIr,
        output_size: [usize; 2],
        options: InterpolateOptionsIr
    },
    shape = x.shape.clone(),
    dtype = x.dtype
);

// 2d grid sampling: leading dims from the input tensor, spatial dims from
// the grid.
impl_ir_create!(
    GridSample2dOpIr {
        tensor: TensorIr,
        grid: TensorIr,
        options: GridSampleOptionsIr
    },
    shape = Shape::new([
        tensor.shape[0],
        tensor.shape[1],
        grid.shape[1],
        grid.shape[2]
    ]),
    dtype = tensor.dtype
);
618
// 1d convolution; all present inputs (x, weight, optional bias) must share
// a dtype.
impl_ir_create!(
    Conv1dOpIr {
        x: TensorIr,
        weight: TensorIr,
        bias: Option<TensorIr>,
        options: Conv1dOptionsIr
    },
    shape = calculate_conv_output_shape(
        &x.shape,
        &weight.shape,
        &options.stride,
        &options.padding,
        &options.dilation,
    )
    .unwrap(),
    dtype = output_dtype(
        [
            Some(&x.dtype),
            Some(&weight.dtype),
            bias.as_ref().map(|b| &b.dtype),
        ]
        .iter()
        .filter_map(|&d| d),
    )
    .unwrap()
);

// 2d convolution; same shape/dtype rules as 1d.
impl_ir_create!(
    Conv2dOpIr {
        x: TensorIr,
        weight: TensorIr,
        bias: Option<TensorIr>,
        options: Conv2dOptionsIr
    },
    shape = calculate_conv_output_shape(
        &x.shape,
        &weight.shape,
        &options.stride,
        &options.padding,
        &options.dilation,
    )
    .unwrap(),
    dtype = output_dtype(
        [
            Some(&x.dtype),
            Some(&weight.dtype),
            bias.as_ref().map(|b| &b.dtype),
        ]
        .iter()
        .filter_map(|&d| d),
    )
    .unwrap()
);

// 3d convolution.
impl_ir_create!(
    Conv3dOpIr {
        x: TensorIr,
        weight: TensorIr,
        bias: Option<TensorIr>,
        options: Conv3dOptionsIr
    },
    shape = calculate_conv_output_shape(
        &x.shape,
        &weight.shape,
        &options.stride,
        &options.padding,
        &options.dilation,
    )
    .unwrap(),
    dtype = output_dtype(
        [
            Some(&x.dtype),
            Some(&weight.dtype),
            bias.as_ref().map(|b| &b.dtype),
        ]
        .iter()
        .filter_map(|&d| d),
    )
    .unwrap()
);

// Deformable 2d convolution: the dtype check also covers `offset` and the
// optional `mask`.
impl_ir_create!(
    DeformConv2dOpIr {
        x: TensorIr,
        offset: TensorIr,
        weight: TensorIr,
        mask: Option<TensorIr>,
        bias: Option<TensorIr>,
        options: DeformableConv2dOptionsIr
    },
    shape = calculate_conv_output_shape(
        &x.shape,
        &weight.shape,
        &options.stride,
        &options.padding,
        &options.dilation,
    )
    .unwrap(),
    dtype = output_dtype(
        [
            Some(&x.dtype),
            Some(&offset.dtype),
            Some(&weight.dtype),
            mask.as_ref().map(|m| &m.dtype),
            bias.as_ref().map(|b| &b.dtype),
        ]
        .iter()
        .filter_map(|&d| d),
    )
    .unwrap()
);

// 1d transposed convolution: output shape additionally depends on
// `padding_out` and `groups`.
impl_ir_create!(
    ConvTranspose1dOpIr {
        x: TensorIr,
        weight: TensorIr,
        bias: Option<TensorIr>,
        options: ConvTranspose1dOptionsIr
    },
    shape = calculate_conv_transpose_output_shape(
        &x.shape,
        &weight.shape,
        &options.stride,
        &options.padding,
        &options.padding_out,
        &options.dilation,
        options.groups,
    )
    .unwrap(),
    dtype = output_dtype(
        [
            Some(&x.dtype),
            Some(&weight.dtype),
            bias.as_ref().map(|b| &b.dtype),
        ]
        .iter()
        .filter_map(|&d| d),
    )
    .unwrap()
);

// 2d transposed convolution.
impl_ir_create!(
    ConvTranspose2dOpIr {
        x: TensorIr,
        weight: TensorIr,
        bias: Option<TensorIr>,
        options: ConvTranspose2dOptionsIr
    },
    shape = calculate_conv_transpose_output_shape(
        &x.shape,
        &weight.shape,
        &options.stride,
        &options.padding,
        &options.padding_out,
        &options.dilation,
        options.groups,
    )
    .unwrap(),
    dtype = output_dtype(
        [
            Some(&x.dtype),
            Some(&weight.dtype),
            bias.as_ref().map(|b| &b.dtype),
        ]
        .iter()
        .filter_map(|&d| d),
    )
    .unwrap()
);

// 3d transposed convolution.
impl_ir_create!(
    ConvTranspose3dOpIr {
        x: TensorIr,
        weight: TensorIr,
        bias: Option<TensorIr>,
        options: ConvTranspose3dOptionsIr
    },
    shape = calculate_conv_transpose_output_shape(
        &x.shape,
        &weight.shape,
        &options.stride,
        &options.padding,
        &options.padding_out,
        &options.dilation,
        options.groups,
    )
    .unwrap(),
    dtype = output_dtype(
        [
            Some(&x.dtype),
            Some(&weight.dtype),
            bias.as_ref().map(|b| &b.dtype),
        ]
        .iter()
        .filter_map(|&d| d),
    )
    .unwrap()
);
817
// Unfold a dimension into sliding windows of `size` advancing by `step`.
impl_ir_create!(
    UnfoldOpIr {
        input: TensorIr,
        dim: usize,
        size: usize,
        step: usize
    },
    shape = calculate_unfold_shape(input.shape.clone(), dim, size, step),
    dtype = input.dtype
);

// Cross product along `dim`; operand shapes broadcast, dtypes must match.
impl_ir_create!(
    CrossOpIr {
        lhs: TensorIr,
        rhs: TensorIr,
        dim: usize
    },
    shape = lhs.shape.broadcast(&rhs.shape).unwrap(),
    dtype = output_dtype([&lhs.dtype, &rhs.dtype]).unwrap()
);

// Quantization: the output dtype records the quantization scheme.
impl_ir_create!(
    QuantizeOpIr {
        tensor: TensorIr,
        qparams: QuantizationParametersIr,
        scheme: QuantScheme
    },
    shape = tensor.shape.clone(),
    dtype = DType::QFloat(scheme)
);
848
849impl DequantizeOpIr {
850 pub fn create(input: TensorIr, dtype: DType, new_id: impl FnOnce() -> TensorId) -> Self {
851 let out = TensorIr::uninit(new_id(), input.shape.clone(), dtype);
852
853 DequantizeOpIr { input, out }
854 }
855}
856
857impl ReduceDimWithIndicesOpIr {
860 pub fn create(
861 tensor: TensorIr,
862 dim: usize,
863 dtype_indices: DType,
864 mut new_id: impl FnMut() -> TensorId,
865 ) -> Self {
866 let mut shape = tensor.shape.clone();
867 shape[dim] = 1;
868 let out = TensorIr::uninit(new_id(), shape.clone(), tensor.dtype);
869 let out_indices = TensorIr::uninit(new_id(), shape.clone(), dtype_indices);
870
871 ReduceDimWithIndicesOpIr {
872 tensor,
873 dim,
874 out,
875 out_indices,
876 }
877 }
878}
879
impl DeformConv2dBackwardOpIr {
    /// Build the IR for the backward pass of a deformable 2d convolution.
    ///
    /// One gradient tensor is allocated per differentiable input, in a fixed
    /// order (input, offset, weight, mask, bias), so `new_id` is `FnMut` and
    /// the generated ids are deterministic.
    #[allow(clippy::too_many_arguments)]
    pub fn create(
        x: TensorIr,
        offset: TensorIr,
        weight: TensorIr,
        mask: Option<TensorIr>,
        bias: Option<TensorIr>,
        out_grad: TensorIr,
        options: DeformableConv2dOptionsIr,
        mut new_id: impl FnMut() -> TensorId,
    ) -> Self {
        // All gradients share a single dtype, validated across the present
        // inputs. NOTE(review): unlike the forward `DeformConv2dOpIr`, the
        // `offset` dtype is NOT part of this check — confirm intentional.
        let dtype = output_dtype(
            [
                Some(&x.dtype),
                Some(&weight.dtype),
                mask.as_ref().map(|m| &m.dtype),
                bias.as_ref().map(|b| &b.dtype),
            ]
            .iter()
            .filter_map(|&d| d),
        )
        .unwrap();

        // Each gradient mirrors the shape of the tensor it differentiates.
        let input_grad = TensorIr::uninit(new_id(), x.shape.clone(), dtype);
        let offset_grad = TensorIr::uninit(new_id(), offset.shape.clone(), dtype);
        let weight_grad = TensorIr::uninit(new_id(), weight.shape.clone(), dtype);
        // Mask/bias gradients exist only when the corresponding input does.
        let mask_grad = mask
            .as_ref()
            .map(|t| TensorIr::uninit(new_id(), t.shape.clone(), dtype));
        let bias_grad = bias
            .as_ref()
            .map(|t| TensorIr::uninit(new_id(), t.shape.clone(), dtype));

        DeformConv2dBackwardOpIr {
            x,
            offset,
            weight,
            mask,
            bias,
            out_grad,
            options,
            input_grad,
            offset_grad,
            weight_grad,
            mask_grad,
            bias_grad,
        }
    }
}
930
931impl MaxPool1dWithIndicesOpIr {
932 #[allow(clippy::too_many_arguments)]
933 pub fn create(
934 x: TensorIr,
935 kernel_size: usize,
936 stride: usize,
937 padding: usize,
938 dilation: usize,
939 ceil_mode: bool,
940 dtype_indices: DType,
941 mut new_id: impl FnMut() -> TensorId,
942 ) -> Self {
943 let shape = calculate_pool_output_shape(
944 &x.shape,
945 &[kernel_size],
946 &[stride],
947 &[padding],
948 &[dilation],
949 ceil_mode,
950 )
951 .unwrap();
952 let out = TensorIr::uninit(new_id(), shape.clone(), x.dtype);
953 let out_indices = TensorIr::uninit(new_id(), shape, dtype_indices);
954
955 MaxPool1dWithIndicesOpIr {
956 x,
957 kernel_size,
958 stride,
959 padding,
960 dilation,
961 ceil_mode,
962 out,
963 out_indices,
964 }
965 }
966}
967
968impl MaxPool2dWithIndicesOpIr {
969 #[allow(clippy::too_many_arguments)]
970 pub fn create(
971 x: TensorIr,
972 kernel_size: [usize; 2],
973 stride: [usize; 2],
974 padding: [usize; 2],
975 dilation: [usize; 2],
976 ceil_mode: bool,
977 dtype_indices: DType,
978 mut new_id: impl FnMut() -> TensorId,
979 ) -> Self {
980 let shape = calculate_pool_output_shape(
981 &x.shape,
982 &kernel_size,
983 &stride,
984 &padding,
985 &dilation,
986 ceil_mode,
987 )
988 .unwrap();
989 let out = TensorIr::uninit(new_id(), shape.clone(), x.dtype);
990 let out_indices = TensorIr::uninit(new_id(), shape, dtype_indices);
991
992 MaxPool2dWithIndicesOpIr {
993 x,
994 kernel_size,
995 stride,
996 padding,
997 dilation,
998 ceil_mode,
999 out,
1000 out_indices,
1001 }
1002 }
1003}