1use burn_tensor::Shape;
2use core::hash::Hash;
3use serde::{Deserialize, Serialize};
4
5use alloc::borrow::ToOwned;
6use alloc::boxed::Box;
7use alloc::{string::String, vec, vec::Vec};
8
9use burn_tensor::{
10 DType, Distribution, Slice,
11 ops::{
12 ConvOptions, ConvTransposeOptions, DeformConvOptions, InterpolateMode, InterpolateOptions,
13 },
14 quantization::QuantScheme,
15};
16
17use crate::{ScalarIr, TensorId, TensorIr, TensorStatus};
18
/// Description of a user-defined (custom) operation.
///
/// Unlike the built-in operation descriptions, the tensors involved are stored as two
/// variable-length lists; [`CustomOpIr::as_fixed`] can view them as fixed-size arrays.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct CustomOpIr {
    /// Identifier of the custom operation.
    pub id: String,
    /// Tensors read by the operation (the ones `mark_read_only` visits).
    pub inputs: Vec<TensorIr>,
    /// Tensors produced by the operation.
    pub outputs: Vec<TensorIr>,
}
29
30impl CustomOpIr {
31 pub fn new(id: &'static str, inputs: &[TensorIr], outputs: &[TensorIr]) -> Self {
33 Self {
34 id: id.to_owned(),
35 inputs: inputs.to_vec(),
36 outputs: outputs.to_vec(),
37 }
38 }
39
40 pub fn as_fixed<const N_IN: usize, const N_OUT: usize>(
42 &self,
43 ) -> (&[TensorIr; N_IN], &[TensorIr; N_OUT]) {
44 (
45 self.inputs.as_slice().try_into().expect(
46 "Wrong number of inputs expected (expected {D}, is {}), check your implementation",
47 ),
48 self.outputs.as_slice().try_into().expect(
49 "Wrong number of outputs expected (expected {D}, is {}), check your implementation",
50 ),
51 )
52 }
53
54 fn nodes(&self) -> Vec<&TensorIr> {
55 self.inputs.iter().chain(self.outputs.iter()).collect()
56 }
57}
58
/// Top-level description of every operation that can be recorded in the IR.
///
/// NOTE: the derived `Hash` and `Serialize` implementations depend on variant order
/// (discriminant / variant index), so do not reorder variants.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum OperationIr {
    /// Basic operation on a float tensor.
    BaseFloat(BaseOperationIr),
    /// Basic operation on an int tensor.
    BaseInt(BaseOperationIr),
    /// Basic operation on a bool tensor.
    BaseBool(BaseOperationIr),
    /// Numeric operation on a float tensor, tagged with its [`DType`].
    NumericFloat(DType, NumericOperationIr),
    /// Numeric operation on an int tensor, tagged with its [`DType`].
    NumericInt(DType, NumericOperationIr),
    /// Bool-specific operation.
    Bool(BoolOperationIr),
    /// Int-specific operation.
    Int(IntOperationIr),
    /// Float-specific operation, tagged with its [`DType`].
    Float(DType, FloatOperationIr),
    /// Module (neural-network layer) operation.
    Module(ModuleOperationIr),
    /// Tensor initialization.
    Init(InitOperationIr),
    /// User-defined operation.
    Custom(CustomOpIr),
    /// Drop of the given tensor.
    Drop(TensorIr),
}
87
/// Operations that only apply to float tensors.
///
/// NOTE: the derived `Hash` and `Serialize` implementations depend on variant order,
/// so do not reorder variants.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum FloatOperationIr {
    /// `exp` operation.
    Exp(UnaryOpIr),
    /// `log` operation.
    Log(UnaryOpIr),
    /// `log1p` operation.
    Log1p(UnaryOpIr),
    /// `erf` (error function) operation.
    Erf(UnaryOpIr),
    /// Power with a scalar exponent.
    PowfScalar(ScalarOpIr),
    /// Square root.
    Sqrt(UnaryOpIr),
    /// Cosine.
    Cos(UnaryOpIr),
    /// Sine.
    Sin(UnaryOpIr),
    /// Hyperbolic tangent.
    Tanh(UnaryOpIr),
    /// Rounding.
    Round(UnaryOpIr),
    /// Floor.
    Floor(UnaryOpIr),
    /// Ceiling.
    Ceil(UnaryOpIr),
    /// Truncation.
    Trunc(UnaryOpIr),
    /// Conversion of a float tensor into an int tensor.
    IntoInt(UnaryOpIr),
    /// Matrix multiplication.
    Matmul(BinaryOpIr),
    /// Cross product (see [`CrossOpIr`]).
    Cross(CrossOpIr),
    /// Random tensor generation.
    Random(RandomOpIr),
    /// Reciprocal.
    Recip(UnaryOpIr),
    /// NaN check.
    IsNan(UnaryOpIr),
    /// Infinity check.
    IsInf(UnaryOpIr),
    /// Quantization.
    Quantize(QuantizeOpIr),
    /// Dequantization.
    Dequantize(DequantizeOpIr),
}
136
/// Module (neural-network layer) operations and their backward passes.
///
/// NOTE: the derived `Hash` and `Serialize` implementations depend on variant order,
/// so do not reorder variants.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum ModuleOperationIr {
    /// Embedding lookup.
    Embedding(EmbeddingOpIr),
    /// Backward pass of the embedding lookup.
    EmbeddingBackward(EmbeddingBackwardOpIr),
    /// 1-D convolution.
    Conv1d(Conv1dOpIr),
    /// 2-D convolution.
    Conv2d(Conv2dOpIr),
    /// 3-D convolution.
    Conv3d(Conv3dOpIr),
    /// 2-D deformable convolution.
    // NOTE(review): boxed, presumably to keep this enum small, since the payload holds
    // many tensors.
    DeformableConv2d(Box<DeformConv2dOpIr>),
    /// Backward pass of the 2-D deformable convolution (boxed like the forward variant).
    DeformableConv2dBackward(Box<DeformConv2dBackwardOpIr>),
    /// 1-D transposed convolution.
    ConvTranspose1d(ConvTranspose1dOpIr),
    /// 2-D transposed convolution.
    ConvTranspose2d(ConvTranspose2dOpIr),
    /// 3-D transposed convolution.
    ConvTranspose3d(ConvTranspose3dOpIr),
    /// 1-D average pooling.
    AvgPool1d(AvgPool1dOpIr),
    /// 2-D average pooling.
    AvgPool2d(AvgPool2dOpIr),
    /// Backward pass of 1-D average pooling.
    AvgPool1dBackward(AvgPool1dBackwardOpIr),
    /// Backward pass of 2-D average pooling.
    AvgPool2dBackward(AvgPool2dBackwardOpIr),
    /// 1-D adaptive average pooling.
    AdaptiveAvgPool1d(AdaptiveAvgPool1dOpIr),
    /// 2-D adaptive average pooling.
    AdaptiveAvgPool2d(AdaptiveAvgPool2dOpIr),
    /// Backward pass of 1-D adaptive average pooling.
    AdaptiveAvgPool1dBackward(AdaptiveAvgPool1dBackwardOpIr),
    /// Backward pass of 2-D adaptive average pooling.
    AdaptiveAvgPool2dBackward(AdaptiveAvgPool2dBackwardOpIr),
    /// 1-D max pooling.
    MaxPool1d(MaxPool1dOpIr),
    /// 1-D max pooling that also outputs indices.
    MaxPool1dWithIndices(MaxPool1dWithIndicesOpIr),
    /// Backward pass of 1-D max pooling with indices.
    MaxPool1dWithIndicesBackward(MaxPool1dWithIndicesBackwardOpIr),
    /// 2-D max pooling.
    MaxPool2d(MaxPool2dOpIr),
    /// 2-D max pooling that also outputs indices.
    MaxPool2dWithIndices(MaxPool2dWithIndicesOpIr),
    /// Backward pass of 2-D max pooling with indices.
    MaxPool2dWithIndicesBackward(MaxPool2dWithIndicesBackwardOpIr),
    /// Interpolation (resizing).
    Interpolate(InterpolateOpIr),
    /// Backward pass of interpolation.
    InterpolateBackward(InterpolateBackwardOpIr),
}
205
/// Operations shared by every tensor kind (float, int and bool).
///
/// NOTE: the derived `Hash` and `Serialize` implementations depend on variant order,
/// so do not reorder variants.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum BaseOperationIr {
    /// Moves the tensor to another device.
    ToDevice(TensorIr),
    /// Reshape.
    Reshape(UnaryOpIr),

    /// Swap of two dimensions.
    SwapDims(SwapDimsOpIr),

    /// Axis permutation.
    Permute(PermuteOpIr),

    /// Flip along the given axes.
    Flip(FlipOpIr),

    /// Broadcast expansion to a target shape.
    Expand(ExpandOpIr),

    /// Unfold into windows along a dimension.
    Unfold(UnfoldOpIr),

    /// Slicing.
    Slice(SliceOpIr),
    /// Assignment into a slice.
    SliceAssign(SliceAssignOpIr),
    /// Element-wise equality between two tensors.
    Equal(BinaryOpIr),
    /// Repetition along a dimension.
    RepeatDim(RepeatDimOpIr),
    /// Concatenation.
    Cat(CatOpIr),
    /// Element type cast.
    Cast(UnaryOpIr),

    /// Cumulative sum along an axis.
    CumSum(DimOpIr),

    /// Cumulative product along an axis.
    CumProd(DimOpIr),
    /// Cumulative minimum along an axis.
    CumMin(DimOpIr),

    /// Cumulative maximum along an axis.
    CumMax(DimOpIr),
    /// Creation of an empty (uninitialized) tensor.
    Empty(TensorIr),
}
313
/// Numeric operations shared by float and int tensors.
///
/// Note that this enum does not derive `Hash`, while `OperationIr` (which contains it)
/// does.
// NOTE(review): presumably a manual `Hash` impl exists elsewhere in this file; confirm.
// The derived `Serialize` depends on variant order, so do not reorder variants.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum NumericOperationIr {
    /// Addition of two tensors.
    Add(BinaryOpIr),
    /// Addition of a scalar.
    AddScalar(ScalarOpIr),
    /// Subtraction of two tensors.
    Sub(BinaryOpIr),
    /// Subtraction of a scalar.
    SubScalar(ScalarOpIr),
    /// Division of two tensors.
    Div(BinaryOpIr),
    /// Division by a scalar.
    DivScalar(ScalarOpIr),
    /// Remainder of two tensors.
    Rem(BinaryOpIr),
    /// Remainder by a scalar.
    RemScalar(ScalarOpIr),
    /// Multiplication of two tensors.
    Mul(BinaryOpIr),
    /// Multiplication by a scalar.
    MulScalar(ScalarOpIr),
    /// Absolute value.
    Abs(UnaryOpIr),
    /// Creation of a tensor filled with ones.
    Ones(TensorIr),
    /// Creation of a tensor filled with zeros.
    Zeros(TensorIr),
    /// Creation of a tensor filled with the given scalar.
    Full((TensorIr, ScalarIr)),
    /// Gather along a dimension.
    Gather(GatherOpIr),
    /// Scatter along a dimension.
    Scatter(ScatterOpIr),
    /// Index selection along a dimension.
    Select(SelectOpIr),
    /// Assignment at selected indices along a dimension.
    SelectAssign(SelectAssignOpIr),
    /// Masked selection between two tensors.
    MaskWhere(MaskWhereOpIr),
    /// Masked fill with a scalar.
    MaskFill(MaskFillOpIr),
    /// Mean along an axis.
    MeanDim(ReduceDimOpIr),
    /// Mean of all elements.
    Mean(UnaryOpIr),
    /// Sum of all elements.
    Sum(UnaryOpIr),
    /// Sum along an axis.
    SumDim(ReduceDimOpIr),

    /// Product of all elements.
    Prod(UnaryOpIr),

    /// Product along an axis.
    ProdDim(ReduceDimOpIr),

    /// Equality with a scalar.
    EqualElem(ScalarOpIr),
    /// Greater-than comparison of two tensors.
    Greater(BinaryOpIr),
    /// Greater-than comparison with a scalar.
    GreaterElem(ScalarOpIr),
    /// Greater-or-equal comparison of two tensors.
    GreaterEqual(BinaryOpIr),
    /// Greater-or-equal comparison with a scalar.
    GreaterEqualElem(ScalarOpIr),
    /// Less-than comparison of two tensors.
    Lower(BinaryOpIr),
    /// Less-than comparison with a scalar.
    LowerElem(ScalarOpIr),
    /// Less-or-equal comparison of two tensors.
    LowerEqual(BinaryOpIr),
    /// Less-or-equal comparison with a scalar.
    LowerEqualElem(ScalarOpIr),
    /// Index of the maximum along an axis.
    ArgMax(ReduceDimOpIr),
    /// Index of the minimum along an axis.
    ArgMin(ReduceDimOpIr),
    /// Maximum of all elements.
    Max(UnaryOpIr),
    /// Maximum along a dimension, with indices.
    MaxDimWithIndices(ReduceDimWithIndicesOpIr),
    /// Minimum along a dimension, with indices.
    MinDimWithIndices(ReduceDimWithIndicesOpIr),
    /// Minimum of all elements.
    Min(UnaryOpIr),
    /// Maximum along an axis.
    MaxDim(ReduceDimOpIr),
    /// Minimum along an axis.
    MinDim(ReduceDimOpIr),
    /// Maximum absolute value of all elements.
    MaxAbs(UnaryOpIr),
    /// Maximum absolute value along an axis.
    MaxAbsDim(ReduceDimOpIr),
    /// Clamp between two scalars.
    Clamp(ClampOpIr),
    /// Random int tensor generation.
    IntRandom(RandomOpIr),
    /// Power with a tensor exponent.
    Powf(BinaryOpIr),
}
560
/// Operations that only apply to int tensors.
///
/// NOTE: the derived `Hash` and `Serialize` implementations depend on variant order,
/// so do not reorder variants.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum IntOperationIr {
    /// Conversion of an int tensor into a float tensor.
    IntoFloat(UnaryOpIr),
    /// Bitwise AND of two tensors.
    BitwiseAnd(BinaryOpIr),
    /// Bitwise AND with a scalar.
    BitwiseAndScalar(ScalarOpIr),
    /// Bitwise OR of two tensors.
    BitwiseOr(BinaryOpIr),
    /// Bitwise OR with a scalar.
    BitwiseOrScalar(ScalarOpIr),
    /// Bitwise XOR of two tensors.
    BitwiseXor(BinaryOpIr),
    /// Bitwise XOR with a scalar.
    BitwiseXorScalar(ScalarOpIr),
    /// Bitwise NOT.
    BitwiseNot(UnaryOpIr),
    /// Left shift by a tensor of shift amounts.
    BitwiseLeftShift(BinaryOpIr),
    /// Left shift by a scalar amount.
    BitwiseLeftShiftScalar(ScalarOpIr),
    /// Right shift by a tensor of shift amounts.
    BitwiseRightShift(BinaryOpIr),
    /// Right shift by a scalar amount.
    BitwiseRightShiftScalar(ScalarOpIr),
    /// Matrix multiplication of int tensors.
    Matmul(BinaryOpIr),
}
613
/// Operations that only apply to bool tensors.
///
/// NOTE: the derived `Hash` and `Serialize` implementations depend on variant order,
/// so do not reorder variants.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum BoolOperationIr {
    /// Creation of an all-`false` tensor.
    Zeros(TensorIr),
    /// Creation of an all-`true` tensor.
    Ones(TensorIr),
    /// Conversion of a bool tensor into a float tensor.
    IntoFloat(UnaryOpIr),
    /// Conversion of a bool tensor into an int tensor.
    IntoInt(UnaryOpIr),
    /// Logical NOT.
    Not(UnaryOpIr),
    /// Logical AND.
    And(BinaryOpIr),
    /// Logical OR.
    Or(BinaryOpIr),
}
634
/// Description of a swap of two tensor dimensions.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct SwapDimsOpIr {
    /// Input tensor.
    pub input: TensorIr,
    /// Output tensor.
    pub out: TensorIr,
    /// First dimension to swap.
    pub dim1: usize,
    /// Second dimension to swap.
    pub dim2: usize,
}
647
/// Description of an axis permutation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct PermuteOpIr {
    /// Input tensor.
    pub input: TensorIr,
    /// Output tensor.
    pub out: TensorIr,
    /// New order of the input axes.
    pub axes: Vec<usize>,
}
658
/// Description of an expansion to a target shape.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct ExpandOpIr {
    /// Input tensor.
    pub input: TensorIr,
    /// Output tensor.
    pub out: TensorIr,
    /// Target shape.
    pub shape: Shape,
}
669
/// Description of an unfold (sliding-window extraction) along one dimension.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct UnfoldOpIr {
    /// Input tensor.
    pub input: TensorIr,
    /// Output tensor.
    pub out: TensorIr,

    /// Dimension to unfold.
    pub dim: usize,
    /// Size of each window.
    pub size: usize,
    /// Step between the starts of consecutive windows.
    pub step: usize,
}
685
/// Description of a flip along a set of axes.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct FlipOpIr {
    /// Input tensor.
    pub input: TensorIr,
    /// Output tensor.
    pub out: TensorIr,
    /// Axes to flip.
    pub axes: Vec<usize>,
}
696
/// Description of random tensor generation.
///
/// Does not derive `Hash`.
// NOTE(review): presumably because `Distribution` is not `Hash`; confirm a manual impl
// exists, since enums deriving `Hash` contain this type.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct RandomOpIr {
    /// Generated tensor.
    pub out: TensorIr,
    /// Distribution the values are drawn from.
    pub distribution: Distribution,
}
703
/// Description of a tensor initialization.
///
/// Only an output tensor is recorded, which is why `OperationIr::mark_read_only` marks
/// nothing for this operation.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct InitOperationIr {
    /// Initialized tensor.
    pub out: TensorIr,
}
712
/// Description of an operation with two tensor operands.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct BinaryOpIr {
    /// Left-hand operand.
    pub lhs: TensorIr,
    /// Right-hand operand.
    pub rhs: TensorIr,
    /// Output tensor.
    pub out: TensorIr,
}
720
/// Description of a cross product between two tensors.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct CrossOpIr {
    /// Left-hand operand.
    pub lhs: TensorIr,
    /// Right-hand operand.
    pub rhs: TensorIr,
    /// Output tensor.
    pub out: TensorIr,
    /// Dimension along which the cross product is taken.
    pub dim: usize,
}
729
/// Description of an operation with a single tensor input.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct UnaryOpIr {
    /// Input tensor.
    pub input: TensorIr,
    /// Output tensor.
    pub out: TensorIr,
}
736
/// Description of an operation between a tensor and a scalar.
///
/// Does not derive `Hash`.
// NOTE(review): presumably because `ScalarIr` is not `Hash`; confirm a manual impl exists.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ScalarOpIr {
    /// Tensor operand.
    pub lhs: TensorIr,
    /// Scalar operand.
    pub rhs: ScalarIr,
    /// Output tensor.
    pub out: TensorIr,
}
746
/// Description of a reduction along one axis.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Hash)]
#[allow(missing_docs)]
pub struct ReduceDimOpIr {
    /// Input tensor.
    pub input: TensorIr,
    /// Output tensor.
    pub out: TensorIr,
    /// Axis to reduce.
    pub axis: usize,
}
754
/// Description of an operation applied along one axis.
///
/// Used by the cumulative operations of `BaseOperationIr` (`CumSum`, `CumProd`,
/// `CumMin`, `CumMax`).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Hash)]
#[allow(missing_docs)]
pub struct DimOpIr {
    /// Input tensor.
    pub input: TensorIr,
    /// Output tensor.
    pub out: TensorIr,
    /// Axis the operation runs along.
    pub axis: usize,
}
764
/// Description of a gather along a dimension.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct GatherOpIr {
    /// Source tensor.
    pub tensor: TensorIr,
    /// Dimension to gather along.
    pub dim: usize,
    /// Index tensor.
    pub indices: TensorIr,
    /// Output tensor.
    pub out: TensorIr,
}
773
/// Description of a scatter along a dimension.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ScatterOpIr {
    /// Destination tensor.
    pub tensor: TensorIr,
    /// Dimension to scatter along.
    pub dim: usize,
    /// Index tensor.
    pub indices: TensorIr,
    /// Values to scatter.
    pub value: TensorIr,
    /// Output tensor.
    pub out: TensorIr,
}
783
/// Description of an index selection along a dimension.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct SelectOpIr {
    /// Source tensor.
    pub tensor: TensorIr,
    /// Dimension to select along.
    pub dim: usize,
    /// Index tensor.
    pub indices: TensorIr,
    /// Output tensor.
    pub out: TensorIr,
}
792
/// Description of an assignment at selected indices along a dimension.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct SelectAssignOpIr {
    /// Destination tensor.
    pub tensor: TensorIr,
    /// Dimension to assign along.
    pub dim: usize,
    /// Index tensor.
    pub indices: TensorIr,
    /// Values to assign.
    pub value: TensorIr,
    /// Output tensor.
    pub out: TensorIr,
}
802
/// Description of a slicing operation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct SliceOpIr {
    /// Source tensor.
    pub tensor: TensorIr,
    /// Slice specification, presumably one entry per sliced dimension.
    pub ranges: Vec<Slice>,
    /// Output tensor.
    pub out: TensorIr,
}
810
811#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
812#[allow(missing_docs)]
813pub struct SliceAssignOpIr {
814 pub tensor: TensorIr,
815 pub ranges: Vec<burn_tensor::Slice>,
816 pub value: TensorIr,
817 pub out: TensorIr,
818}
819
/// Description of a masked selection between two tensors.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaskWhereOpIr {
    /// Base tensor.
    pub tensor: TensorIr,
    /// Mask tensor.
    pub mask: TensorIr,
    /// Tensor supplying values where the mask applies.
    pub value: TensorIr,
    /// Output tensor.
    pub out: TensorIr,
}
828
/// Description of a masked fill with a scalar value.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaskFillOpIr {
    /// Base tensor.
    pub tensor: TensorIr,
    /// Mask tensor.
    pub mask: TensorIr,
    /// Fill value.
    pub value: ScalarIr,
    /// Output tensor.
    pub out: TensorIr,
}
837
/// Description of a clamp between two scalar bounds.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ClampOpIr {
    /// Input tensor.
    pub tensor: TensorIr,
    /// Lower bound.
    pub min: ScalarIr,
    /// Upper bound.
    pub max: ScalarIr,
    /// Output tensor.
    pub out: TensorIr,
}
846
/// Description of a repetition along one dimension.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct RepeatDimOpIr {
    /// Input tensor.
    pub tensor: TensorIr,
    /// Dimension to repeat along.
    pub dim: usize,
    /// Number of repetitions.
    pub times: usize,
    /// Output tensor.
    pub out: TensorIr,
}
855
/// Description of a concatenation of several tensors.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct CatOpIr {
    /// Tensors to concatenate, in order.
    pub tensors: Vec<TensorIr>,
    /// Dimension to concatenate along.
    pub dim: usize,
    /// Output tensor.
    pub out: TensorIr,
}
863
/// Description of a reduction along one dimension that also outputs indices.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ReduceDimWithIndicesOpIr {
    /// Input tensor.
    pub tensor: TensorIr,
    /// Dimension to reduce.
    pub dim: usize,
    /// Reduced values.
    pub out: TensorIr,
    /// Indices of the reduced values.
    pub out_indices: TensorIr,
}
872
/// Description of an embedding lookup.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct EmbeddingOpIr {
    /// Embedding table.
    pub weights: TensorIr,
    /// Lookup indices.
    pub indices: TensorIr,
    /// Output tensor.
    pub out: TensorIr,
}
880
/// Description of the backward pass of an embedding lookup.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct EmbeddingBackwardOpIr {
    /// Embedding table.
    pub weights: TensorIr,
    /// Gradient with respect to the forward output.
    pub out_grad: TensorIr,
    /// Lookup indices used in the forward pass.
    pub indices: TensorIr,
    /// Output tensor.
    pub out: TensorIr,
}
889
/// Description of a 1-D convolution.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv1dOpIr {
    /// Input tensor.
    pub x: TensorIr,
    /// Convolution kernel.
    pub weight: TensorIr,
    /// Optional bias.
    pub bias: Option<TensorIr>,
    /// Convolution options.
    pub options: Conv1dOptionsIr,
    /// Output tensor.
    pub out: TensorIr,
}
899
/// Description of a 2-D convolution.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv2dOpIr {
    /// Input tensor.
    pub x: TensorIr,
    /// Convolution kernel.
    pub weight: TensorIr,
    /// Optional bias.
    pub bias: Option<TensorIr>,
    /// Convolution options.
    pub options: Conv2dOptionsIr,
    /// Output tensor.
    pub out: TensorIr,
}
909
/// Description of a 2-D deformable convolution.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct DeformConv2dOpIr {
    /// Input tensor.
    pub x: TensorIr,
    /// Sampling offsets.
    pub offset: TensorIr,
    /// Convolution kernel.
    pub weight: TensorIr,
    /// Optional modulation mask.
    pub mask: Option<TensorIr>,
    /// Optional bias.
    pub bias: Option<TensorIr>,
    /// Deformable convolution options.
    pub options: DeformableConv2dOptionsIr,
    /// Output tensor.
    pub out: TensorIr,
}
921
/// Description of the backward pass of a 2-D deformable convolution.
///
/// The first group of fields mirrors the forward inputs; the `*_grad` fields are the
/// gradients produced for each of them.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct DeformConv2dBackwardOpIr {
    /// Forward input tensor.
    pub x: TensorIr,
    /// Forward sampling offsets.
    pub offset: TensorIr,
    /// Forward convolution kernel.
    pub weight: TensorIr,
    /// Forward modulation mask, if any.
    pub mask: Option<TensorIr>,
    /// Forward bias, if any.
    pub bias: Option<TensorIr>,
    /// Gradient with respect to the forward output.
    pub out_grad: TensorIr,
    /// Deformable convolution options.
    pub options: DeformableConv2dOptionsIr,
    /// Gradient for `x`.
    pub input_grad: TensorIr,
    /// Gradient for `offset`.
    pub offset_grad: TensorIr,
    /// Gradient for `weight`.
    pub weight_grad: TensorIr,
    /// Gradient for `mask`, when a mask is present.
    pub mask_grad: Option<TensorIr>,
    /// Gradient for `bias`, when a bias is present.
    pub bias_grad: Option<TensorIr>,
}
938
/// Description of a 3-D convolution.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv3dOpIr {
    /// Input tensor.
    pub x: TensorIr,
    /// Convolution kernel.
    pub weight: TensorIr,
    /// Optional bias.
    pub bias: Option<TensorIr>,
    /// Convolution options.
    pub options: Conv3dOptionsIr,
    /// Output tensor.
    pub out: TensorIr,
}
948
/// Description of a 1-D transposed convolution.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose1dOpIr {
    /// Input tensor.
    pub x: TensorIr,
    /// Convolution kernel.
    pub weight: TensorIr,
    /// Optional bias.
    pub bias: Option<TensorIr>,
    /// Transposed convolution options.
    pub options: ConvTranspose1dOptionsIr,
    /// Output tensor.
    pub out: TensorIr,
}
958
/// Description of a 2-D transposed convolution.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose2dOpIr {
    /// Input tensor.
    pub x: TensorIr,
    /// Convolution kernel.
    pub weight: TensorIr,
    /// Optional bias.
    pub bias: Option<TensorIr>,
    /// Transposed convolution options.
    pub options: ConvTranspose2dOptionsIr,
    /// Output tensor.
    pub out: TensorIr,
}
968
/// Description of a 3-D transposed convolution.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose3dOpIr {
    /// Input tensor.
    pub x: TensorIr,
    /// Convolution kernel.
    pub weight: TensorIr,
    /// Optional bias.
    pub bias: Option<TensorIr>,
    /// Transposed convolution options.
    pub options: ConvTranspose3dOptionsIr,
    /// Output tensor.
    pub out: TensorIr,
}
978
/// IR mirror of `ConvOptions<1>` (converted both ways by the `From` impls below).
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv1dOptionsIr {
    /// Stride.
    pub stride: [usize; 1],
    /// Padding.
    pub padding: [usize; 1],
    /// Dilation.
    pub dilation: [usize; 1],
    /// Number of groups.
    pub groups: usize,
}
987
/// IR mirror of `ConvOptions<2>` (converted both ways by the `From` impls below).
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv2dOptionsIr {
    /// Stride per spatial dimension.
    pub stride: [usize; 2],
    /// Padding per spatial dimension.
    pub padding: [usize; 2],
    /// Dilation per spatial dimension.
    pub dilation: [usize; 2],
    /// Number of groups.
    pub groups: usize,
}
996
/// IR mirror of `DeformConvOptions<2>` (converted both ways by the `From` impls below).
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct DeformableConv2dOptionsIr {
    /// Stride per spatial dimension.
    pub stride: [usize; 2],
    /// Padding per spatial dimension.
    pub padding: [usize; 2],
    /// Dilation per spatial dimension.
    pub dilation: [usize; 2],
    /// Number of weight groups.
    pub weight_groups: usize,
    /// Number of offset groups.
    pub offset_groups: usize,
}
1006
/// IR mirror of `ConvOptions<3>` (converted both ways by the `From` impls below).
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv3dOptionsIr {
    /// Stride per spatial dimension.
    pub stride: [usize; 3],
    /// Padding per spatial dimension.
    pub padding: [usize; 3],
    /// Dilation per spatial dimension.
    pub dilation: [usize; 3],
    /// Number of groups.
    pub groups: usize,
}
1015
/// IR mirror of `ConvTransposeOptions<1>` (converted both ways by the `From` impls below).
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose1dOptionsIr {
    /// Stride.
    pub stride: [usize; 1],
    /// Padding.
    pub padding: [usize; 1],
    /// Output padding.
    pub padding_out: [usize; 1],
    /// Dilation.
    pub dilation: [usize; 1],
    /// Number of groups.
    pub groups: usize,
}
1025
/// IR mirror of `ConvTransposeOptions<2>` (converted both ways by the `From` impls below).
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose2dOptionsIr {
    /// Stride per spatial dimension.
    pub stride: [usize; 2],
    /// Padding per spatial dimension.
    pub padding: [usize; 2],
    /// Output padding per spatial dimension.
    pub padding_out: [usize; 2],
    /// Dilation per spatial dimension.
    pub dilation: [usize; 2],
    /// Number of groups.
    pub groups: usize,
}
1035
/// IR mirror of `ConvTransposeOptions<3>` (converted both ways by the `From` impls below).
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose3dOptionsIr {
    /// Stride per spatial dimension.
    pub stride: [usize; 3],
    /// Padding per spatial dimension.
    pub padding: [usize; 3],
    /// Output padding per spatial dimension.
    pub padding_out: [usize; 3],
    /// Dilation per spatial dimension.
    pub dilation: [usize; 3],
    /// Number of groups.
    pub groups: usize,
}
1045
/// Quantization parameters attached to a quantize operation.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct QuantizationParametersIr {
    /// Tensor holding the quantization scales.
    pub scales: TensorIr,
}
1052
/// Description of a quantization.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct QuantizeOpIr {
    /// Tensor to quantize.
    pub tensor: TensorIr,
    /// Quantization parameters.
    pub qparams: QuantizationParametersIr,
    /// Quantization scheme.
    pub scheme: QuantScheme,
    /// Quantized output tensor.
    pub out: TensorIr,
}
1061
/// Description of a dequantization.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct DequantizeOpIr {
    /// Quantized input tensor.
    pub input: TensorIr,
    /// Dequantized output tensor.
    pub out: TensorIr,
}
1068
1069impl From<ConvOptions<1>> for Conv1dOptionsIr {
1070 fn from(value: ConvOptions<1>) -> Self {
1071 Self {
1072 stride: value.stride,
1073 padding: value.padding,
1074 dilation: value.dilation,
1075 groups: value.groups,
1076 }
1077 }
1078}
1079
1080impl From<ConvOptions<2>> for Conv2dOptionsIr {
1081 fn from(value: ConvOptions<2>) -> Self {
1082 Self {
1083 stride: value.stride,
1084 padding: value.padding,
1085 dilation: value.dilation,
1086 groups: value.groups,
1087 }
1088 }
1089}
1090
1091impl From<ConvOptions<3>> for Conv3dOptionsIr {
1092 fn from(value: ConvOptions<3>) -> Self {
1093 Self {
1094 stride: value.stride,
1095 padding: value.padding,
1096 dilation: value.dilation,
1097 groups: value.groups,
1098 }
1099 }
1100}
1101
1102impl From<DeformConvOptions<2>> for DeformableConv2dOptionsIr {
1103 fn from(value: DeformConvOptions<2>) -> Self {
1104 Self {
1105 stride: value.stride,
1106 padding: value.padding,
1107 dilation: value.dilation,
1108 weight_groups: value.weight_groups,
1109 offset_groups: value.offset_groups,
1110 }
1111 }
1112}
1113
1114impl From<ConvTransposeOptions<1>> for ConvTranspose1dOptionsIr {
1115 fn from(value: ConvTransposeOptions<1>) -> Self {
1116 Self {
1117 stride: value.stride,
1118 padding: value.padding,
1119 padding_out: value.padding_out,
1120 dilation: value.dilation,
1121 groups: value.groups,
1122 }
1123 }
1124}
1125
1126impl From<ConvTransposeOptions<2>> for ConvTranspose2dOptionsIr {
1127 fn from(value: ConvTransposeOptions<2>) -> Self {
1128 Self {
1129 stride: value.stride,
1130 padding: value.padding,
1131 padding_out: value.padding_out,
1132 dilation: value.dilation,
1133 groups: value.groups,
1134 }
1135 }
1136}
1137
1138impl From<ConvTransposeOptions<3>> for ConvTranspose3dOptionsIr {
1139 fn from(value: ConvTransposeOptions<3>) -> Self {
1140 Self {
1141 stride: value.stride,
1142 padding: value.padding,
1143 padding_out: value.padding_out,
1144 dilation: value.dilation,
1145 groups: value.groups,
1146 }
1147 }
1148}
1149
1150impl From<Conv1dOptionsIr> for ConvOptions<1> {
1151 fn from(val: Conv1dOptionsIr) -> Self {
1152 ConvOptions {
1153 stride: val.stride,
1154 padding: val.padding,
1155 dilation: val.dilation,
1156 groups: val.groups,
1157 }
1158 }
1159}
1160
1161impl From<Conv2dOptionsIr> for ConvOptions<2> {
1162 fn from(val: Conv2dOptionsIr) -> Self {
1163 ConvOptions {
1164 stride: val.stride,
1165 padding: val.padding,
1166 dilation: val.dilation,
1167 groups: val.groups,
1168 }
1169 }
1170}
1171
1172impl From<Conv3dOptionsIr> for ConvOptions<3> {
1173 fn from(val: Conv3dOptionsIr) -> Self {
1174 ConvOptions {
1175 stride: val.stride,
1176 padding: val.padding,
1177 dilation: val.dilation,
1178 groups: val.groups,
1179 }
1180 }
1181}
1182
1183impl From<DeformableConv2dOptionsIr> for DeformConvOptions<2> {
1184 fn from(value: DeformableConv2dOptionsIr) -> Self {
1185 DeformConvOptions {
1186 stride: value.stride,
1187 padding: value.padding,
1188 dilation: value.dilation,
1189 weight_groups: value.weight_groups,
1190 offset_groups: value.offset_groups,
1191 }
1192 }
1193}
1194
1195impl From<ConvTranspose1dOptionsIr> for ConvTransposeOptions<1> {
1196 fn from(val: ConvTranspose1dOptionsIr) -> Self {
1197 ConvTransposeOptions {
1198 stride: val.stride,
1199 padding: val.padding,
1200 padding_out: val.padding_out,
1201 dilation: val.dilation,
1202 groups: val.groups,
1203 }
1204 }
1205}
1206
1207impl From<ConvTranspose2dOptionsIr> for ConvTransposeOptions<2> {
1208 fn from(val: ConvTranspose2dOptionsIr) -> Self {
1209 ConvTransposeOptions {
1210 stride: val.stride,
1211 padding: val.padding,
1212 padding_out: val.padding_out,
1213 dilation: val.dilation,
1214 groups: val.groups,
1215 }
1216 }
1217}
1218
1219impl From<ConvTranspose3dOptionsIr> for ConvTransposeOptions<3> {
1220 fn from(val: ConvTranspose3dOptionsIr) -> Self {
1221 ConvTransposeOptions {
1222 stride: val.stride,
1223 padding: val.padding,
1224 padding_out: val.padding_out,
1225 dilation: val.dilation,
1226 groups: val.groups,
1227 }
1228 }
1229}
1230
/// Description of a 1-D average pooling.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AvgPool1dOpIr {
    /// Input tensor.
    pub x: TensorIr,
    /// Pooling window size.
    pub kernel_size: usize,
    /// Stride.
    pub stride: usize,
    /// Padding.
    pub padding: usize,
    /// Whether padded positions count toward the average.
    pub count_include_pad: bool,
    /// Output tensor.
    pub out: TensorIr,
}
1241
/// Description of a 2-D average pooling.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AvgPool2dOpIr {
    /// Input tensor.
    pub x: TensorIr,
    /// Pooling window size per spatial dimension.
    pub kernel_size: [usize; 2],
    /// Stride per spatial dimension.
    pub stride: [usize; 2],
    /// Padding per spatial dimension.
    pub padding: [usize; 2],
    /// Whether padded positions count toward the average.
    pub count_include_pad: bool,
    /// Output tensor.
    pub out: TensorIr,
}
1252
/// Description of the backward pass of a 1-D average pooling.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AvgPool1dBackwardOpIr {
    /// Forward input tensor.
    pub x: TensorIr,
    /// Gradient with respect to the forward output.
    pub grad: TensorIr,
    /// Pooling window size.
    pub kernel_size: usize,
    /// Stride.
    pub stride: usize,
    /// Padding.
    pub padding: usize,
    /// Whether padded positions count toward the average.
    pub count_include_pad: bool,
    /// Output tensor (gradient for `x`).
    pub out: TensorIr,
}
1264
/// Description of the backward pass of a 2-D average pooling.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AvgPool2dBackwardOpIr {
    /// Forward input tensor.
    pub x: TensorIr,
    /// Gradient with respect to the forward output.
    pub grad: TensorIr,
    /// Pooling window size per spatial dimension.
    pub kernel_size: [usize; 2],
    /// Stride per spatial dimension.
    pub stride: [usize; 2],
    /// Padding per spatial dimension.
    pub padding: [usize; 2],
    /// Whether padded positions count toward the average.
    pub count_include_pad: bool,
    /// Output tensor (gradient for `x`).
    pub out: TensorIr,
}
1276
/// Description of a 1-D adaptive average pooling.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool1dOpIr {
    /// Input tensor.
    pub x: TensorIr,
    /// Requested output length.
    pub output_size: usize,
    /// Output tensor.
    pub out: TensorIr,
}
1284
/// Description of a 2-D adaptive average pooling.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool2dOpIr {
    /// Input tensor.
    pub x: TensorIr,
    /// Requested output size per spatial dimension.
    pub output_size: [usize; 2],
    /// Output tensor.
    pub out: TensorIr,
}
1292
/// Description of the backward pass of a 1-D adaptive average pooling.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool1dBackwardOpIr {
    /// Forward input tensor.
    pub x: TensorIr,
    /// Gradient with respect to the forward output.
    pub grad: TensorIr,
    /// Output tensor (gradient for `x`).
    pub out: TensorIr,
}
1300
/// Description of the backward pass of a 2-D adaptive average pooling.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool2dBackwardOpIr {
    /// Forward input tensor.
    pub x: TensorIr,
    /// Gradient with respect to the forward output.
    pub grad: TensorIr,
    /// Output tensor (gradient for `x`).
    pub out: TensorIr,
}
1308
/// Description of a 1-D max pooling.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool1dOpIr {
    /// Input tensor.
    pub x: TensorIr,
    /// Pooling window size.
    pub kernel_size: usize,
    /// Stride.
    pub stride: usize,
    /// Padding.
    pub padding: usize,
    /// Dilation.
    pub dilation: usize,
    /// Output tensor.
    pub out: TensorIr,
}
1319
/// Description of a 1-D max pooling that also outputs the indices of the maxima.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool1dWithIndicesOpIr {
    /// Input tensor.
    pub x: TensorIr,
    /// Pooling window size.
    pub kernel_size: usize,
    /// Stride.
    pub stride: usize,
    /// Padding.
    pub padding: usize,
    /// Dilation.
    pub dilation: usize,
    /// Pooled values.
    pub out: TensorIr,
    /// Indices of the pooled values.
    pub out_indices: TensorIr,
}
1331
/// Description of the backward pass of a 1-D max pooling with indices.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool1dWithIndicesBackwardOpIr {
    /// Forward input tensor.
    pub x: TensorIr,
    /// Gradient with respect to the forward output.
    pub grad: TensorIr,
    /// Indices produced by the forward pass.
    pub indices: TensorIr,
    /// Pooling window size.
    pub kernel_size: usize,
    /// Stride.
    pub stride: usize,
    /// Padding.
    pub padding: usize,
    /// Dilation.
    pub dilation: usize,
    /// Output tensor (gradient for `x`).
    pub out: TensorIr,
}
1344
/// Description of a 2-D max pooling.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool2dOpIr {
    /// Input tensor.
    pub x: TensorIr,
    /// Pooling window size per spatial dimension.
    pub kernel_size: [usize; 2],
    /// Stride per spatial dimension.
    pub stride: [usize; 2],
    /// Padding per spatial dimension.
    pub padding: [usize; 2],
    /// Dilation per spatial dimension.
    pub dilation: [usize; 2],
    /// Output tensor.
    pub out: TensorIr,
}
1355
/// Description of a 2-D max pooling that also outputs the indices of the maxima.
#[allow(missing_docs)]
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct MaxPool2dWithIndicesOpIr {
    /// Input tensor.
    pub x: TensorIr,
    /// Pooling window size per spatial dimension.
    pub kernel_size: [usize; 2],
    /// Stride per spatial dimension.
    pub stride: [usize; 2],
    /// Padding per spatial dimension.
    pub padding: [usize; 2],
    /// Dilation per spatial dimension.
    pub dilation: [usize; 2],
    /// Pooled values.
    pub out: TensorIr,
    /// Indices of the pooled values.
    pub out_indices: TensorIr,
}
1367
/// Description of the backward pass of a 2-D max pooling with indices.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool2dWithIndicesBackwardOpIr {
    /// Forward input tensor.
    pub x: TensorIr,
    /// Gradient with respect to the forward output.
    pub grad: TensorIr,
    /// Indices produced by the forward pass.
    pub indices: TensorIr,
    /// Pooling window size per spatial dimension.
    pub kernel_size: [usize; 2],
    /// Stride per spatial dimension.
    pub stride: [usize; 2],
    /// Padding per spatial dimension.
    pub padding: [usize; 2],
    /// Dilation per spatial dimension.
    pub dilation: [usize; 2],
    /// Output tensor (gradient for `x`).
    pub out: TensorIr,
}
1380
/// IR mirror of `InterpolateMode` (converted both ways by the `From` impls below).
///
/// NOTE: the derived `Hash` and `Serialize` implementations depend on variant order,
/// so do not reorder variants.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub enum InterpolateModeIr {
    /// Nearest-neighbor interpolation.
    Nearest,
    /// Bilinear interpolation.
    Bilinear,
    /// Bicubic interpolation.
    Bicubic,
}
1388
/// IR mirror of `InterpolateOptions` (converted both ways by the `From` impls below).
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct InterpolateOptionsIr {
    /// Interpolation mode.
    pub mode: InterpolateModeIr,
}
1394
/// Description of an interpolation (resize).
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct InterpolateOpIr {
    /// Input tensor.
    pub x: TensorIr,
    /// Requested output size per spatial dimension.
    pub output_size: [usize; 2],
    /// Interpolation options.
    pub options: InterpolateOptionsIr,
    /// Output tensor.
    pub out: TensorIr,
}
1403
1404impl From<InterpolateModeIr> for InterpolateMode {
1405 fn from(val: InterpolateModeIr) -> Self {
1406 match val {
1407 InterpolateModeIr::Nearest => Self::Nearest,
1408 InterpolateModeIr::Bilinear => Self::Bilinear,
1409 InterpolateModeIr::Bicubic => Self::Bicubic,
1410 }
1411 }
1412}
1413
1414impl From<InterpolateOptionsIr> for InterpolateOptions {
1415 fn from(val: InterpolateOptionsIr) -> Self {
1416 Self {
1417 mode: val.mode.into(),
1418 }
1419 }
1420}
1421
1422impl From<InterpolateMode> for InterpolateModeIr {
1423 fn from(val: InterpolateMode) -> Self {
1424 match val {
1425 InterpolateMode::Nearest => Self::Nearest,
1426 InterpolateMode::Bilinear => Self::Bilinear,
1427 InterpolateMode::Bicubic => Self::Bicubic,
1428 }
1429 }
1430}
1431
1432impl From<InterpolateOptions> for InterpolateOptionsIr {
1433 fn from(val: InterpolateOptions) -> Self {
1434 Self {
1435 mode: val.mode.into(),
1436 }
1437 }
1438}
1439
/// Description of the backward pass of an interpolation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct InterpolateBackwardOpIr {
    /// Forward input tensor.
    pub x: TensorIr,
    /// Gradient with respect to the forward output.
    pub grad: TensorIr,
    /// Output size used in the forward pass.
    pub output_size: [usize; 2],
    /// Interpolation options used in the forward pass.
    pub options: InterpolateOptionsIr,
    /// Output tensor (gradient for `x`).
    pub out: TensorIr,
}
1449
1450impl OperationIr {
1451 pub fn nodes(&self) -> Vec<&TensorIr> {
1453 match self {
1454 OperationIr::BaseFloat(repr) => repr.nodes(),
1455 OperationIr::BaseInt(repr) => repr.nodes(),
1456 OperationIr::BaseBool(repr) => repr.nodes(),
1457 OperationIr::NumericFloat(_dtype, repr) => repr.nodes(),
1458 OperationIr::NumericInt(_dtype, repr) => repr.nodes(),
1459 OperationIr::Bool(repr) => repr.nodes(),
1460 OperationIr::Int(repr) => repr.nodes(),
1461 OperationIr::Float(_dtype, repr) => repr.nodes(),
1462 OperationIr::Module(repr) => repr.nodes(),
1463 OperationIr::Init(repr) => repr.nodes(),
1464 OperationIr::Custom(repr) => repr.nodes(),
1465 OperationIr::Drop(repr) => vec![repr],
1466 }
1467 }
1468
1469 pub fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
1474 match self {
1475 OperationIr::BaseFloat(repr) => repr.mark_read_only(nodes),
1476 OperationIr::BaseInt(repr) => repr.mark_read_only(nodes),
1477 OperationIr::BaseBool(repr) => repr.mark_read_only(nodes),
1478 OperationIr::NumericFloat(_dtype, repr) => repr.mark_read_only(nodes),
1479 OperationIr::NumericInt(_dtype, repr) => repr.mark_read_only(nodes),
1480 OperationIr::Bool(repr) => repr.mark_read_only(nodes),
1481 OperationIr::Int(repr) => repr.mark_read_only(nodes),
1482 OperationIr::Float(_dtype, repr) => repr.mark_read_only(nodes),
1483 OperationIr::Module(repr) => repr.mark_read_only(nodes),
1484 OperationIr::Init(_) => Vec::new(),
1485 OperationIr::Drop(repr) => {
1486 let mut output = Vec::new();
1487 repr.mark_read_only(nodes, &mut output);
1488 output
1489 }
1490 OperationIr::Custom(repr) => {
1491 let mut output = Vec::new();
1492
1493 for input in repr.inputs.iter_mut() {
1494 input.mark_read_only(nodes, &mut output);
1495 }
1496
1497 output
1498 }
1499 }
1500 }
1501}
1502
1503impl BaseOperationIr {
1504 fn nodes(&self) -> Vec<&TensorIr> {
1505 match self {
1506 BaseOperationIr::ToDevice(repr) => vec![repr],
1507 BaseOperationIr::Reshape(repr) => {
1508 vec![&repr.input, &repr.out]
1509 }
1510 BaseOperationIr::SwapDims(repr) => {
1511 vec![&repr.input, &repr.out]
1512 }
1513 BaseOperationIr::Permute(repr) => {
1514 vec![&repr.input, &repr.out]
1515 }
1516
1517 BaseOperationIr::Expand(repr) => {
1518 vec![&repr.input, &repr.out]
1519 }
1520
1521 BaseOperationIr::Flip(repr) => {
1522 vec![&repr.input, &repr.out]
1523 }
1524 BaseOperationIr::Slice(repr) => {
1525 vec![&repr.tensor, &repr.out]
1526 }
1527 BaseOperationIr::SliceAssign(repr) => {
1528 vec![&repr.tensor, &repr.value, &repr.out]
1529 }
1530 BaseOperationIr::Equal(repr) => {
1531 vec![&repr.lhs, &repr.rhs, &repr.out]
1532 }
1533 BaseOperationIr::RepeatDim(repr) => {
1534 vec![&repr.tensor, &repr.out]
1535 }
1536 BaseOperationIr::Cat(repr) => {
1537 let mut tensors: Vec<_> = repr.tensors.iter().collect();
1538 tensors.push(&repr.out);
1539 tensors
1540 }
1541 BaseOperationIr::Cast(repr) => vec![&repr.input, &repr.out],
1542 BaseOperationIr::CumSum(repr) => vec![&repr.input, &repr.out],
1543 BaseOperationIr::CumProd(repr) => vec![&repr.input, &repr.out],
1544 BaseOperationIr::CumMin(repr) => vec![&repr.input, &repr.out],
1545 BaseOperationIr::CumMax(repr) => vec![&repr.input, &repr.out],
1546 BaseOperationIr::Empty(repr) => vec![repr],
1547 BaseOperationIr::Unfold(repr) => {
1548 vec![&repr.input, &repr.out]
1549 }
1550 }
1551 }
1552
1553 fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
1554 let mut output = Vec::new();
1555
1556 match self {
1557 BaseOperationIr::ToDevice(repr) => {
1558 repr.mark_read_only(nodes, &mut output);
1559 }
1560 BaseOperationIr::Reshape(repr) => {
1561 repr.input.mark_read_only(nodes, &mut output);
1562 }
1563 BaseOperationIr::SwapDims(repr) => {
1564 repr.input.mark_read_only(nodes, &mut output);
1565 }
1566 BaseOperationIr::Permute(repr) => {
1567 repr.input.mark_read_only(nodes, &mut output);
1568 }
1569
1570 BaseOperationIr::Expand(repr) => {
1571 repr.input.mark_read_only(nodes, &mut output);
1572 }
1573
1574 BaseOperationIr::Flip(repr) => {
1575 repr.input.mark_read_only(nodes, &mut output);
1576 }
1577 BaseOperationIr::Slice(repr) => {
1578 repr.tensor.mark_read_only(nodes, &mut output);
1579 }
1580 BaseOperationIr::SliceAssign(repr) => {
1581 repr.tensor.mark_read_only(nodes, &mut output);
1582 repr.value.mark_read_only(nodes, &mut output);
1583 }
1584 BaseOperationIr::Equal(repr) => {
1585 repr.lhs.mark_read_only(nodes, &mut output);
1586 repr.rhs.mark_read_only(nodes, &mut output);
1587 }
1588 BaseOperationIr::RepeatDim(repr) => {
1589 repr.tensor.mark_read_only(nodes, &mut output);
1590 }
1591 BaseOperationIr::Cat(repr) => {
1592 for t in repr.tensors.iter_mut() {
1593 t.mark_read_only(nodes, &mut output);
1594 }
1595 }
1596 BaseOperationIr::Cast(repr) => {
1597 repr.input.mark_read_only(nodes, &mut output);
1598 }
1599 BaseOperationIr::CumSum(repr) => {
1600 repr.input.mark_read_only(nodes, &mut output);
1601 }
1602 BaseOperationIr::CumProd(repr) => {
1603 repr.input.mark_read_only(nodes, &mut output);
1604 }
1605 BaseOperationIr::CumMin(repr) => {
1606 repr.input.mark_read_only(nodes, &mut output);
1607 }
1608 BaseOperationIr::CumMax(repr) => {
1609 repr.input.mark_read_only(nodes, &mut output);
1610 }
1611 BaseOperationIr::Unfold(repr) => {
1612 repr.input.mark_read_only(nodes, &mut output);
1613 }
1614 BaseOperationIr::Empty(_) => {}
1615 };
1616
1617 output
1618 }
1619}
1620
1621impl NumericOperationIr {
1622 fn nodes(&self) -> Vec<&TensorIr> {
1623 match self {
1624 NumericOperationIr::Add(repr) => {
1625 vec![&repr.lhs, &repr.rhs, &repr.out]
1626 }
1627 NumericOperationIr::AddScalar(repr) => {
1628 vec![&repr.lhs, &repr.out]
1629 }
1630 NumericOperationIr::Sub(repr) => {
1631 vec![&repr.lhs, &repr.rhs, &repr.out]
1632 }
1633 NumericOperationIr::SubScalar(repr) => {
1634 vec![&repr.lhs, &repr.out]
1635 }
1636 NumericOperationIr::Mul(repr) => {
1637 vec![&repr.lhs, &repr.rhs, &repr.out]
1638 }
1639 NumericOperationIr::MulScalar(repr) => {
1640 vec![&repr.lhs, &repr.out]
1641 }
1642 NumericOperationIr::Div(repr) => {
1643 vec![&repr.lhs, &repr.rhs, &repr.out]
1644 }
1645 NumericOperationIr::DivScalar(repr) => {
1646 vec![&repr.lhs, &repr.out]
1647 }
1648 NumericOperationIr::Rem(repr) => {
1649 vec![&repr.lhs, &repr.rhs, &repr.out]
1650 }
1651 NumericOperationIr::RemScalar(repr) => {
1652 vec![&repr.lhs, &repr.out]
1653 }
1654 NumericOperationIr::Ones(repr) => vec![repr],
1655 NumericOperationIr::Gather(repr) => {
1656 vec![&repr.tensor, &repr.indices, &repr.out]
1657 }
1658 NumericOperationIr::Scatter(repr) => {
1659 vec![&repr.tensor, &repr.indices, &repr.value, &repr.out]
1660 }
1661 NumericOperationIr::Select(repr) => {
1662 vec![&repr.tensor, &repr.indices, &repr.out]
1663 }
1664 NumericOperationIr::SelectAssign(repr) => {
1665 vec![&repr.tensor, &repr.indices, &repr.value, &repr.out]
1666 }
1667 NumericOperationIr::MaskWhere(repr) => {
1668 vec![&repr.tensor, &repr.mask, &repr.value, &repr.out]
1669 }
1670 NumericOperationIr::MaskFill(repr) => {
1671 vec![&repr.tensor, &repr.mask, &repr.out]
1672 }
1673 NumericOperationIr::EqualElem(repr) => {
1674 vec![&repr.lhs, &repr.out]
1675 }
1676 NumericOperationIr::GreaterElem(repr) => {
1677 vec![&repr.lhs, &repr.out]
1678 }
1679 NumericOperationIr::GreaterEqualElem(repr) => {
1680 vec![&repr.lhs, &repr.out]
1681 }
1682 NumericOperationIr::LowerElem(repr) => {
1683 vec![&repr.lhs, &repr.out]
1684 }
1685 NumericOperationIr::LowerEqualElem(repr) => {
1686 vec![&repr.lhs, &repr.out]
1687 }
1688 NumericOperationIr::Greater(repr) => {
1689 vec![&repr.lhs, &repr.rhs, &repr.out]
1690 }
1691 NumericOperationIr::GreaterEqual(repr) => {
1692 vec![&repr.lhs, &repr.rhs, &repr.out]
1693 }
1694 NumericOperationIr::Lower(repr) => {
1695 vec![&repr.lhs, &repr.rhs, &repr.out]
1696 }
1697 NumericOperationIr::LowerEqual(repr) => {
1698 vec![&repr.lhs, &repr.rhs, &repr.out]
1699 }
1700 NumericOperationIr::ArgMax(repr) => {
1701 vec![&repr.input, &repr.out]
1702 }
1703 NumericOperationIr::ArgMin(repr) => {
1704 vec![&repr.input, &repr.out]
1705 }
1706 NumericOperationIr::Clamp(repr) => {
1707 vec![&repr.tensor, &repr.out]
1708 }
1709 NumericOperationIr::Abs(repr) => {
1710 vec![&repr.input, &repr.out]
1711 }
1712 NumericOperationIr::Zeros(repr) => vec![repr],
1713 NumericOperationIr::Full(repr) => vec![&repr.0],
1714 NumericOperationIr::MeanDim(repr) => {
1715 vec![&repr.input, &repr.out]
1716 }
1717 NumericOperationIr::Mean(repr) => {
1718 vec![&repr.input, &repr.out]
1719 }
1720 NumericOperationIr::Sum(repr) => {
1721 vec![&repr.input, &repr.out]
1722 }
1723 NumericOperationIr::SumDim(repr) => {
1724 vec![&repr.input, &repr.out]
1725 }
1726 NumericOperationIr::Prod(repr) => {
1727 vec![&repr.input, &repr.out]
1728 }
1729 NumericOperationIr::ProdDim(repr) => {
1730 vec![&repr.input, &repr.out]
1731 }
1732 NumericOperationIr::Max(repr) => {
1733 vec![&repr.input, &repr.out]
1734 }
1735 NumericOperationIr::MaxDimWithIndices(repr) => {
1736 vec![&repr.tensor, &repr.out_indices, &repr.out]
1737 }
1738 NumericOperationIr::MinDimWithIndices(repr) => {
1739 vec![&repr.tensor, &repr.out_indices, &repr.out]
1740 }
1741 NumericOperationIr::Min(repr) => {
1742 vec![&repr.input, &repr.out]
1743 }
1744 NumericOperationIr::MaxDim(repr) => {
1745 vec![&repr.input, &repr.out]
1746 }
1747 NumericOperationIr::MinDim(repr) => {
1748 vec![&repr.input, &repr.out]
1749 }
1750 NumericOperationIr::MaxAbs(repr) => {
1751 vec![&repr.input, &repr.out]
1752 }
1753 NumericOperationIr::MaxAbsDim(repr) => {
1754 vec![&repr.input, &repr.out]
1755 }
1756 NumericOperationIr::IntRandom(repr) => {
1757 vec![&repr.out]
1758 }
1759 NumericOperationIr::Powf(repr) => {
1760 vec![&repr.lhs, &repr.rhs, &repr.out]
1761 }
1762 }
1763 }
1764 fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
1765 let mut output = Vec::new();
1766
1767 match self {
1768 NumericOperationIr::Add(repr) => {
1769 repr.lhs.mark_read_only(nodes, &mut output);
1770 repr.rhs.mark_read_only(nodes, &mut output);
1771 }
1772 NumericOperationIr::AddScalar(repr) => {
1773 repr.lhs.mark_read_only(nodes, &mut output);
1774 }
1775 NumericOperationIr::Sub(repr) => {
1776 repr.lhs.mark_read_only(nodes, &mut output);
1777 repr.rhs.mark_read_only(nodes, &mut output);
1778 }
1779 NumericOperationIr::SubScalar(repr) => {
1780 repr.lhs.mark_read_only(nodes, &mut output);
1781 }
1782 NumericOperationIr::Mul(repr) => {
1783 repr.lhs.mark_read_only(nodes, &mut output);
1784 repr.rhs.mark_read_only(nodes, &mut output);
1785 }
1786 NumericOperationIr::MulScalar(repr) => {
1787 repr.lhs.mark_read_only(nodes, &mut output);
1788 }
1789 NumericOperationIr::Div(repr) => {
1790 repr.lhs.mark_read_only(nodes, &mut output);
1791 repr.rhs.mark_read_only(nodes, &mut output);
1792 }
1793 NumericOperationIr::DivScalar(repr) => {
1794 repr.lhs.mark_read_only(nodes, &mut output);
1795 }
1796 NumericOperationIr::Rem(repr) => {
1797 repr.lhs.mark_read_only(nodes, &mut output);
1798 repr.rhs.mark_read_only(nodes, &mut output);
1799 }
1800 NumericOperationIr::RemScalar(repr) => {
1801 repr.lhs.mark_read_only(nodes, &mut output);
1802 }
1803 NumericOperationIr::Ones(_) => {}
1804 NumericOperationIr::Gather(repr) => {
1805 repr.tensor.mark_read_only(nodes, &mut output);
1806 repr.indices.mark_read_only(nodes, &mut output);
1807 }
1808 NumericOperationIr::Scatter(repr) => {
1809 repr.tensor.mark_read_only(nodes, &mut output);
1810 repr.indices.mark_read_only(nodes, &mut output);
1811 repr.value.mark_read_only(nodes, &mut output);
1812 }
1813 NumericOperationIr::Select(repr) => {
1814 repr.tensor.mark_read_only(nodes, &mut output);
1815 repr.indices.mark_read_only(nodes, &mut output);
1816 }
1817 NumericOperationIr::SelectAssign(repr) => {
1818 repr.tensor.mark_read_only(nodes, &mut output);
1819 repr.indices.mark_read_only(nodes, &mut output);
1820 repr.value.mark_read_only(nodes, &mut output);
1821 }
1822 NumericOperationIr::MaskWhere(repr) => {
1823 repr.tensor.mark_read_only(nodes, &mut output);
1824 repr.mask.mark_read_only(nodes, &mut output);
1825 repr.value.mark_read_only(nodes, &mut output);
1826 }
1827 NumericOperationIr::MaskFill(repr) => {
1828 repr.tensor.mark_read_only(nodes, &mut output);
1829 repr.mask.mark_read_only(nodes, &mut output);
1830 }
1831 NumericOperationIr::EqualElem(repr) => {
1832 repr.lhs.mark_read_only(nodes, &mut output);
1833 }
1834 NumericOperationIr::GreaterElem(repr) => {
1835 repr.lhs.mark_read_only(nodes, &mut output);
1836 }
1837 NumericOperationIr::GreaterEqualElem(repr) => {
1838 repr.lhs.mark_read_only(nodes, &mut output);
1839 }
1840 NumericOperationIr::LowerElem(repr) => {
1841 repr.lhs.mark_read_only(nodes, &mut output);
1842 }
1843 NumericOperationIr::LowerEqualElem(repr) => {
1844 repr.lhs.mark_read_only(nodes, &mut output);
1845 }
1846 NumericOperationIr::Greater(repr) => {
1847 repr.lhs.mark_read_only(nodes, &mut output);
1848 repr.rhs.mark_read_only(nodes, &mut output);
1849 }
1850 NumericOperationIr::GreaterEqual(repr) => {
1851 repr.lhs.mark_read_only(nodes, &mut output);
1852 repr.rhs.mark_read_only(nodes, &mut output);
1853 }
1854 NumericOperationIr::Lower(repr) => {
1855 repr.lhs.mark_read_only(nodes, &mut output);
1856 repr.rhs.mark_read_only(nodes, &mut output);
1857 }
1858 NumericOperationIr::LowerEqual(repr) => {
1859 repr.lhs.mark_read_only(nodes, &mut output);
1860 repr.rhs.mark_read_only(nodes, &mut output);
1861 }
1862 NumericOperationIr::ArgMax(repr) => {
1863 repr.input.mark_read_only(nodes, &mut output);
1864 }
1865 NumericOperationIr::ArgMin(repr) => {
1866 repr.input.mark_read_only(nodes, &mut output);
1867 }
1868 NumericOperationIr::Clamp(repr) => {
1869 repr.tensor.mark_read_only(nodes, &mut output);
1870 }
1871 NumericOperationIr::Abs(repr) => {
1872 repr.input.mark_read_only(nodes, &mut output);
1873 }
1874 NumericOperationIr::Zeros(_) => {}
1875 NumericOperationIr::Full(_) => {}
1876 NumericOperationIr::MeanDim(repr) => {
1877 repr.input.mark_read_only(nodes, &mut output);
1878 }
1879 NumericOperationIr::Mean(repr) => {
1880 repr.input.mark_read_only(nodes, &mut output);
1881 }
1882 NumericOperationIr::Sum(repr) => {
1883 repr.input.mark_read_only(nodes, &mut output);
1884 }
1885 NumericOperationIr::SumDim(repr) => {
1886 repr.input.mark_read_only(nodes, &mut output);
1887 }
1888 NumericOperationIr::Prod(repr) => {
1889 repr.input.mark_read_only(nodes, &mut output);
1890 }
1891 NumericOperationIr::ProdDim(repr) => {
1892 repr.input.mark_read_only(nodes, &mut output);
1893 }
1894 NumericOperationIr::Max(repr) => {
1895 repr.input.mark_read_only(nodes, &mut output);
1896 }
1897 NumericOperationIr::MaxDimWithIndices(repr) => {
1898 repr.tensor.mark_read_only(nodes, &mut output);
1899 }
1900 NumericOperationIr::MinDimWithIndices(repr) => {
1901 repr.tensor.mark_read_only(nodes, &mut output);
1902 }
1903 NumericOperationIr::Min(repr) => {
1904 repr.input.mark_read_only(nodes, &mut output);
1905 }
1906 NumericOperationIr::MaxDim(repr) => {
1907 repr.input.mark_read_only(nodes, &mut output);
1908 }
1909 NumericOperationIr::MinDim(repr) => {
1910 repr.input.mark_read_only(nodes, &mut output);
1911 }
1912 NumericOperationIr::MaxAbs(repr) => {
1913 repr.input.mark_read_only(nodes, &mut output);
1914 }
1915 NumericOperationIr::MaxAbsDim(repr) => {
1916 repr.input.mark_read_only(nodes, &mut output);
1917 }
1918 NumericOperationIr::IntRandom(_) => {}
1919 NumericOperationIr::Powf(repr) => {
1920 repr.lhs.mark_read_only(nodes, &mut output);
1921 repr.rhs.mark_read_only(nodes, &mut output);
1922 }
1923 };
1924
1925 output
1926 }
1927}
1928
1929impl FloatOperationIr {
1930 fn nodes(&self) -> Vec<&TensorIr> {
1931 match self {
1932 FloatOperationIr::Matmul(repr) => {
1933 vec![&repr.lhs, &repr.rhs, &repr.out]
1934 }
1935 FloatOperationIr::Cross(repr) => {
1936 vec![&repr.lhs, &repr.rhs, &repr.out]
1937 }
1938 FloatOperationIr::Random(repr) => vec![&repr.out],
1939 FloatOperationIr::Exp(repr) => vec![&repr.input, &repr.out],
1940 FloatOperationIr::Log(repr) => vec![&repr.input, &repr.out],
1941 FloatOperationIr::Log1p(repr) => vec![&repr.input, &repr.out],
1942 FloatOperationIr::Erf(repr) => vec![&repr.input, &repr.out],
1943 FloatOperationIr::Recip(repr) => vec![&repr.input, &repr.out],
1944 FloatOperationIr::PowfScalar(repr) => vec![&repr.lhs, &repr.out],
1945 FloatOperationIr::Sqrt(repr) => vec![&repr.input, &repr.out],
1946 FloatOperationIr::Cos(repr) => vec![&repr.input, &repr.out],
1947 FloatOperationIr::Sin(repr) => vec![&repr.input, &repr.out],
1948 FloatOperationIr::Tanh(repr) => vec![&repr.input, &repr.out],
1949 FloatOperationIr::Round(repr) => vec![&repr.input, &repr.out],
1950 FloatOperationIr::Floor(repr) => vec![&repr.input, &repr.out],
1951 FloatOperationIr::Ceil(repr) => vec![&repr.input, &repr.out],
1952 FloatOperationIr::Trunc(repr) => vec![&repr.input, &repr.out],
1953 FloatOperationIr::IntoInt(repr) => vec![&repr.input, &repr.out],
1954 FloatOperationIr::Quantize(repr) => vec![&repr.tensor, &repr.qparams.scales, &repr.out],
1955 FloatOperationIr::Dequantize(repr) => vec![&repr.input, &repr.out],
1956 FloatOperationIr::IsNan(repr) => vec![&repr.input, &repr.out],
1957 FloatOperationIr::IsInf(repr) => vec![&repr.input, &repr.out],
1958 }
1959 }
1960
1961 fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
1962 let mut output = Vec::new();
1963
1964 match self {
1965 FloatOperationIr::Matmul(repr) => {
1966 repr.lhs.mark_read_only(nodes, &mut output);
1967 repr.rhs.mark_read_only(nodes, &mut output);
1968 }
1969 FloatOperationIr::Cross(repr) => {
1970 repr.lhs.mark_read_only(nodes, &mut output);
1971 repr.rhs.mark_read_only(nodes, &mut output);
1972 }
1973 FloatOperationIr::Random(_) => {}
1974 FloatOperationIr::Exp(repr) => {
1975 repr.input.mark_read_only(nodes, &mut output);
1976 }
1977 FloatOperationIr::Log(repr) => {
1978 repr.input.mark_read_only(nodes, &mut output);
1979 }
1980 FloatOperationIr::Log1p(repr) => {
1981 repr.input.mark_read_only(nodes, &mut output);
1982 }
1983 FloatOperationIr::Erf(repr) => {
1984 repr.input.mark_read_only(nodes, &mut output);
1985 }
1986 FloatOperationIr::Recip(repr) => {
1987 repr.input.mark_read_only(nodes, &mut output);
1988 }
1989 FloatOperationIr::PowfScalar(repr) => {
1990 repr.lhs.mark_read_only(nodes, &mut output);
1991 }
1992 FloatOperationIr::Sqrt(repr) => {
1993 repr.input.mark_read_only(nodes, &mut output);
1994 }
1995 FloatOperationIr::Cos(repr) => {
1996 repr.input.mark_read_only(nodes, &mut output);
1997 }
1998 FloatOperationIr::Sin(repr) => {
1999 repr.input.mark_read_only(nodes, &mut output);
2000 }
2001 FloatOperationIr::Tanh(repr) => {
2002 repr.input.mark_read_only(nodes, &mut output);
2003 }
2004 FloatOperationIr::Round(repr) => {
2005 repr.input.mark_read_only(nodes, &mut output);
2006 }
2007 FloatOperationIr::Floor(repr) => {
2008 repr.input.mark_read_only(nodes, &mut output);
2009 }
2010 FloatOperationIr::Ceil(repr) => {
2011 repr.input.mark_read_only(nodes, &mut output);
2012 }
2013 FloatOperationIr::Trunc(repr) => {
2014 repr.input.mark_read_only(nodes, &mut output);
2015 }
2016 FloatOperationIr::Quantize(repr) => {
2017 repr.tensor.mark_read_only(nodes, &mut output);
2018 repr.qparams.scales.mark_read_only(nodes, &mut output);
2019 }
2020 FloatOperationIr::Dequantize(repr) => {
2021 repr.input.mark_read_only(nodes, &mut output);
2022 }
2023 FloatOperationIr::IntoInt(repr) => {
2024 repr.input.mark_read_only(nodes, &mut output);
2025 }
2026 FloatOperationIr::IsNan(repr) => {
2027 repr.input.mark_read_only(nodes, &mut output);
2028 }
2029 FloatOperationIr::IsInf(repr) => {
2030 repr.input.mark_read_only(nodes, &mut output);
2031 }
2032 };
2033
2034 output
2035 }
2036}
2037
2038impl IntOperationIr {
2039 fn nodes(&self) -> Vec<&TensorIr> {
2040 match self {
2041 IntOperationIr::Matmul(repr) => {
2042 vec![&repr.lhs, &repr.rhs, &repr.out]
2043 }
2044 IntOperationIr::IntoFloat(repr) => vec![&repr.input, &repr.out],
2045 IntOperationIr::BitwiseAnd(repr) => {
2046 vec![&repr.lhs, &repr.rhs, &repr.out]
2047 }
2048 IntOperationIr::BitwiseAndScalar(repr) => {
2049 vec![&repr.lhs, &repr.out]
2050 }
2051 IntOperationIr::BitwiseOr(repr) => {
2052 vec![&repr.lhs, &repr.rhs, &repr.out]
2053 }
2054 IntOperationIr::BitwiseOrScalar(repr) => {
2055 vec![&repr.lhs, &repr.out]
2056 }
2057 IntOperationIr::BitwiseXor(repr) => {
2058 vec![&repr.lhs, &repr.rhs, &repr.out]
2059 }
2060 IntOperationIr::BitwiseXorScalar(repr) => {
2061 vec![&repr.lhs, &repr.out]
2062 }
2063 IntOperationIr::BitwiseNot(repr) => {
2064 vec![&repr.input, &repr.out]
2065 }
2066 IntOperationIr::BitwiseLeftShift(repr) => {
2067 vec![&repr.lhs, &repr.rhs, &repr.out]
2068 }
2069 IntOperationIr::BitwiseLeftShiftScalar(repr) => {
2070 vec![&repr.lhs, &repr.out]
2071 }
2072 IntOperationIr::BitwiseRightShift(repr) => {
2073 vec![&repr.lhs, &repr.rhs, &repr.out]
2074 }
2075 IntOperationIr::BitwiseRightShiftScalar(repr) => {
2076 vec![&repr.lhs, &repr.out]
2077 }
2078 }
2079 }
2080
2081 fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
2082 let mut output = Vec::new();
2083
2084 match self {
2085 IntOperationIr::Matmul(repr) => {
2086 repr.lhs.mark_read_only(nodes, &mut output);
2087 repr.rhs.mark_read_only(nodes, &mut output);
2088 }
2089 IntOperationIr::IntoFloat(repr) => {
2090 repr.input.mark_read_only(nodes, &mut output);
2091 }
2092 IntOperationIr::BitwiseAnd(repr) => {
2093 repr.lhs.mark_read_only(nodes, &mut output);
2094 repr.rhs.mark_read_only(nodes, &mut output);
2095 }
2096 IntOperationIr::BitwiseAndScalar(repr) => {
2097 repr.lhs.mark_read_only(nodes, &mut output);
2098 }
2099 IntOperationIr::BitwiseOr(repr) => {
2100 repr.lhs.mark_read_only(nodes, &mut output);
2101 repr.rhs.mark_read_only(nodes, &mut output);
2102 }
2103 IntOperationIr::BitwiseOrScalar(repr) => {
2104 repr.lhs.mark_read_only(nodes, &mut output);
2105 }
2106 IntOperationIr::BitwiseXor(repr) => {
2107 repr.lhs.mark_read_only(nodes, &mut output);
2108 repr.rhs.mark_read_only(nodes, &mut output);
2109 }
2110 IntOperationIr::BitwiseXorScalar(repr) => {
2111 repr.lhs.mark_read_only(nodes, &mut output);
2112 }
2113 IntOperationIr::BitwiseNot(repr) => {
2114 repr.input.mark_read_only(nodes, &mut output);
2115 }
2116 IntOperationIr::BitwiseLeftShift(repr) => {
2117 repr.lhs.mark_read_only(nodes, &mut output);
2118 repr.rhs.mark_read_only(nodes, &mut output);
2119 }
2120 IntOperationIr::BitwiseLeftShiftScalar(repr) => {
2121 repr.lhs.mark_read_only(nodes, &mut output);
2122 }
2123 IntOperationIr::BitwiseRightShift(repr) => {
2124 repr.lhs.mark_read_only(nodes, &mut output);
2125 repr.rhs.mark_read_only(nodes, &mut output);
2126 }
2127 IntOperationIr::BitwiseRightShiftScalar(repr) => {
2128 repr.lhs.mark_read_only(nodes, &mut output);
2129 }
2130 };
2131
2132 output
2133 }
2134}
2135
2136impl BoolOperationIr {
2137 fn nodes(&self) -> Vec<&TensorIr> {
2138 match self {
2139 BoolOperationIr::Zeros(repr) => vec![repr],
2140 BoolOperationIr::Ones(repr) => vec![repr],
2141 BoolOperationIr::IntoFloat(repr) => vec![&repr.input, &repr.out],
2142 BoolOperationIr::IntoInt(repr) => vec![&repr.input, &repr.out],
2143 BoolOperationIr::Not(repr) => vec![&repr.input, &repr.out],
2144 BoolOperationIr::And(repr) => vec![&repr.lhs, &repr.rhs, &repr.out],
2145 BoolOperationIr::Or(repr) => vec![&repr.lhs, &repr.rhs, &repr.out],
2146 }
2147 }
2148 fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
2149 let mut output = Vec::new();
2150
2151 match self {
2152 BoolOperationIr::Zeros(_) => {}
2153 BoolOperationIr::Ones(_) => {}
2154 BoolOperationIr::IntoFloat(repr) => {
2155 repr.input.mark_read_only(nodes, &mut output);
2156 }
2157 BoolOperationIr::IntoInt(repr) => {
2158 repr.input.mark_read_only(nodes, &mut output);
2159 }
2160 BoolOperationIr::Not(repr) => {
2161 repr.input.mark_read_only(nodes, &mut output);
2162 }
2163 BoolOperationIr::And(repr) => {
2164 repr.lhs.mark_read_only(nodes, &mut output);
2165 repr.rhs.mark_read_only(nodes, &mut output);
2166 }
2167 BoolOperationIr::Or(repr) => {
2168 repr.lhs.mark_read_only(nodes, &mut output);
2169 repr.rhs.mark_read_only(nodes, &mut output);
2170 }
2171 };
2172
2173 output
2174 }
2175}
2176
2177impl ModuleOperationIr {
2178 fn nodes(&self) -> Vec<&TensorIr> {
2179 match self {
2180 ModuleOperationIr::Embedding(repr) => {
2181 vec![&repr.weights, &repr.indices, &repr.out]
2182 }
2183 ModuleOperationIr::EmbeddingBackward(repr) => {
2184 vec![&repr.weights, &repr.out_grad, &repr.indices, &repr.out]
2185 }
2186 ModuleOperationIr::Conv1d(repr) => {
2187 if let Some(bias) = &repr.bias {
2188 vec![&repr.x, &repr.weight, &bias, &repr.out]
2189 } else {
2190 vec![&repr.x, &repr.weight, &repr.out]
2191 }
2192 }
2193 ModuleOperationIr::Conv2d(repr) => {
2194 if let Some(bias) = &repr.bias {
2195 vec![&repr.x, &repr.weight, &bias, &repr.out]
2196 } else {
2197 vec![&repr.x, &repr.weight, &repr.out]
2198 }
2199 }
2200 ModuleOperationIr::Conv3d(repr) => {
2201 if let Some(bias) = &repr.bias {
2202 vec![&repr.x, &repr.weight, &bias, &repr.out]
2203 } else {
2204 vec![&repr.x, &repr.weight, &repr.out]
2205 }
2206 }
2207 ModuleOperationIr::DeformableConv2d(repr) => match (&repr.mask, &repr.bias) {
2208 (Some(mask), Some(bias)) => vec![&repr.x, &repr.offset, &repr.weight, &mask, &bias],
2209 (Some(mask), None) => vec![&repr.x, &repr.offset, &repr.weight, &mask],
2210 (None, Some(bias)) => vec![&repr.x, &repr.offset, &repr.weight, &bias],
2211 (None, None) => vec![&repr.x, &repr.offset, &repr.weight],
2212 },
2213 ModuleOperationIr::DeformableConv2dBackward(repr) => {
2214 let mut nodes = Vec::with_capacity(6);
2215 nodes.push(&repr.x);
2216 nodes.push(&repr.offset);
2217 nodes.push(&repr.weight);
2218 nodes.push(&repr.out_grad);
2219
2220 if let Some(mask) = repr.mask.as_ref() {
2221 nodes.push(mask);
2222 }
2223 if let Some(bias) = repr.bias.as_ref() {
2224 nodes.push(bias);
2225 }
2226
2227 nodes
2228 }
2229 ModuleOperationIr::ConvTranspose1d(repr) => {
2230 if let Some(bias) = &repr.bias {
2231 vec![&repr.x, &repr.weight, &bias, &repr.out]
2232 } else {
2233 vec![&repr.x, &repr.weight, &repr.out]
2234 }
2235 }
2236 ModuleOperationIr::ConvTranspose2d(repr) => {
2237 if let Some(bias) = &repr.bias {
2238 vec![&repr.x, &repr.weight, &bias, &repr.out]
2239 } else {
2240 vec![&repr.x, &repr.weight, &repr.out]
2241 }
2242 }
2243 ModuleOperationIr::ConvTranspose3d(repr) => {
2244 if let Some(bias) = &repr.bias {
2245 vec![&repr.x, &repr.weight, &bias, &repr.out]
2246 } else {
2247 vec![&repr.x, &repr.weight, &repr.out]
2248 }
2249 }
2250 ModuleOperationIr::AvgPool1d(repr) => {
2251 vec![&repr.x, &repr.out]
2252 }
2253 ModuleOperationIr::AvgPool2d(repr) => {
2254 vec![&repr.x, &repr.out]
2255 }
2256 ModuleOperationIr::AvgPool1dBackward(repr) => {
2257 vec![&repr.x, &repr.out, &repr.grad]
2258 }
2259 ModuleOperationIr::AvgPool2dBackward(repr) => {
2260 vec![&repr.x, &repr.out, &repr.grad]
2261 }
2262 ModuleOperationIr::AdaptiveAvgPool1d(repr) => {
2263 vec![&repr.x, &repr.out]
2264 }
2265 ModuleOperationIr::AdaptiveAvgPool2d(repr) => {
2266 vec![&repr.x, &repr.out]
2267 }
2268 ModuleOperationIr::AdaptiveAvgPool1dBackward(repr) => {
2269 vec![&repr.x, &repr.out, &repr.grad]
2270 }
2271 ModuleOperationIr::AdaptiveAvgPool2dBackward(repr) => {
2272 vec![&repr.x, &repr.out, &repr.grad]
2273 }
2274 ModuleOperationIr::MaxPool1d(repr) => {
2275 vec![&repr.x, &repr.out]
2276 }
2277 ModuleOperationIr::MaxPool1dWithIndices(repr) => {
2278 vec![&repr.x, &repr.out, &repr.out_indices]
2279 }
2280 ModuleOperationIr::MaxPool1dWithIndicesBackward(repr) => {
2281 vec![&repr.x, &repr.out, &repr.indices, &repr.grad]
2282 }
2283 ModuleOperationIr::MaxPool2d(repr) => {
2284 vec![&repr.x, &repr.out]
2285 }
2286 ModuleOperationIr::MaxPool2dWithIndices(repr) => {
2287 vec![&repr.x, &repr.out, &repr.out_indices]
2288 }
2289 ModuleOperationIr::MaxPool2dWithIndicesBackward(repr) => {
2290 vec![&repr.x, &repr.out, &repr.indices, &repr.grad]
2291 }
2292 ModuleOperationIr::Interpolate(repr) => {
2293 vec![&repr.x, &repr.out]
2294 }
2295 ModuleOperationIr::InterpolateBackward(repr) => {
2296 vec![&repr.x, &repr.out, &repr.grad]
2297 }
2298 }
2299 }
2300
2301 fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
2302 let mut output = Vec::new();
2303
2304 match self {
2305 ModuleOperationIr::Embedding(repr) => {
2306 repr.weights.mark_read_only(nodes, &mut output);
2307 repr.indices.mark_read_only(nodes, &mut output);
2308 }
2309 ModuleOperationIr::EmbeddingBackward(repr) => {
2310 repr.weights.mark_read_only(nodes, &mut output);
2311 repr.out_grad.mark_read_only(nodes, &mut output);
2312 repr.indices.mark_read_only(nodes, &mut output);
2313 }
2314 ModuleOperationIr::Conv1d(repr) => {
2315 repr.x.mark_read_only(nodes, &mut output);
2316 repr.weight.mark_read_only(nodes, &mut output);
2317
2318 if let Some(bias) = &mut repr.bias {
2319 bias.mark_read_only(nodes, &mut output);
2320 }
2321 }
2322 ModuleOperationIr::Conv2d(repr) => {
2323 repr.x.mark_read_only(nodes, &mut output);
2324 repr.weight.mark_read_only(nodes, &mut output);
2325
2326 if let Some(bias) = &mut repr.bias {
2327 bias.mark_read_only(nodes, &mut output);
2328 }
2329 }
2330 ModuleOperationIr::Conv3d(repr) => {
2331 repr.x.mark_read_only(nodes, &mut output);
2332 repr.weight.mark_read_only(nodes, &mut output);
2333
2334 if let Some(bias) = &mut repr.bias {
2335 bias.mark_read_only(nodes, &mut output);
2336 }
2337 }
2338 ModuleOperationIr::DeformableConv2d(repr) => {
2339 repr.x.mark_read_only(nodes, &mut output);
2340 repr.weight.mark_read_only(nodes, &mut output);
2341 repr.offset.mark_read_only(nodes, &mut output);
2342
2343 match (&mut repr.mask, &mut repr.bias) {
2344 (Some(mask), Some(bias)) => {
2345 mask.mark_read_only(nodes, &mut output);
2346 bias.mark_read_only(nodes, &mut output);
2347 }
2348 (Some(mask), None) => {
2349 mask.mark_read_only(nodes, &mut output);
2350 }
2351 (None, Some(bias)) => {
2352 bias.mark_read_only(nodes, &mut output);
2353 }
2354 (None, None) => {}
2355 };
2356 }
2357 ModuleOperationIr::DeformableConv2dBackward(repr) => {
2358 repr.x.mark_read_only(nodes, &mut output);
2359 repr.weight.mark_read_only(nodes, &mut output);
2360 repr.offset.mark_read_only(nodes, &mut output);
2361 repr.out_grad.mark_read_only(nodes, &mut output);
2362
2363 if let Some(mask) = repr.mask.as_mut() {
2364 mask.mark_read_only(nodes, &mut output);
2365 }
2366 if let Some(bias) = repr.bias.as_mut() {
2367 bias.mark_read_only(nodes, &mut output);
2368 }
2369 }
2370 ModuleOperationIr::ConvTranspose1d(repr) => {
2371 repr.x.mark_read_only(nodes, &mut output);
2372 repr.weight.mark_read_only(nodes, &mut output);
2373
2374 if let Some(bias) = &mut repr.bias {
2375 bias.mark_read_only(nodes, &mut output);
2376 }
2377 }
2378 ModuleOperationIr::ConvTranspose2d(repr) => {
2379 repr.x.mark_read_only(nodes, &mut output);
2380 repr.weight.mark_read_only(nodes, &mut output);
2381
2382 if let Some(bias) = &mut repr.bias {
2383 bias.mark_read_only(nodes, &mut output);
2384 }
2385 }
2386 ModuleOperationIr::ConvTranspose3d(repr) => {
2387 repr.x.mark_read_only(nodes, &mut output);
2388 repr.weight.mark_read_only(nodes, &mut output);
2389
2390 if let Some(bias) = &mut repr.bias {
2391 bias.mark_read_only(nodes, &mut output);
2392 }
2393 }
2394 ModuleOperationIr::AvgPool1d(repr) => {
2395 repr.x.mark_read_only(nodes, &mut output);
2396 }
2397 ModuleOperationIr::AvgPool2d(repr) => {
2398 repr.x.mark_read_only(nodes, &mut output);
2399 }
2400 ModuleOperationIr::AvgPool1dBackward(repr) => {
2401 repr.x.mark_read_only(nodes, &mut output);
2402 repr.grad.mark_read_only(nodes, &mut output);
2403 }
2404 ModuleOperationIr::AvgPool2dBackward(repr) => {
2405 repr.x.mark_read_only(nodes, &mut output);
2406 repr.grad.mark_read_only(nodes, &mut output);
2407 }
2408 ModuleOperationIr::AdaptiveAvgPool1d(repr) => {
2409 repr.x.mark_read_only(nodes, &mut output);
2410 }
2411 ModuleOperationIr::AdaptiveAvgPool2d(repr) => {
2412 repr.x.mark_read_only(nodes, &mut output);
2413 }
2414 ModuleOperationIr::AdaptiveAvgPool1dBackward(repr) => {
2415 repr.x.mark_read_only(nodes, &mut output);
2416 repr.grad.mark_read_only(nodes, &mut output);
2417 }
2418 ModuleOperationIr::AdaptiveAvgPool2dBackward(repr) => {
2419 repr.x.mark_read_only(nodes, &mut output);
2420 repr.grad.mark_read_only(nodes, &mut output);
2421 }
2422 ModuleOperationIr::MaxPool1d(repr) => {
2423 repr.x.mark_read_only(nodes, &mut output);
2424 }
2425 ModuleOperationIr::MaxPool1dWithIndices(repr) => {
2426 repr.x.mark_read_only(nodes, &mut output);
2427 }
2428 ModuleOperationIr::MaxPool1dWithIndicesBackward(repr) => {
2429 repr.x.mark_read_only(nodes, &mut output);
2430 repr.grad.mark_read_only(nodes, &mut output);
2431 }
2432 ModuleOperationIr::MaxPool2d(repr) => {
2433 repr.x.mark_read_only(nodes, &mut output);
2434 }
2435 ModuleOperationIr::MaxPool2dWithIndices(repr) => {
2436 repr.x.mark_read_only(nodes, &mut output);
2437 }
2438 ModuleOperationIr::MaxPool2dWithIndicesBackward(repr) => {
2439 repr.x.mark_read_only(nodes, &mut output);
2440 repr.grad.mark_read_only(nodes, &mut output);
2441 }
2442 ModuleOperationIr::Interpolate(repr) => {
2443 repr.x.mark_read_only(nodes, &mut output);
2444 }
2445 ModuleOperationIr::InterpolateBackward(repr) => {
2446 repr.x.mark_read_only(nodes, &mut output);
2447 repr.grad.mark_read_only(nodes, &mut output);
2448 }
2449 };
2450
2451 output
2452 }
2453}
2454
2455impl core::hash::Hash for InitOperationIr {
2456 fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
2457 self.out.hash(state);
2458 }
2459}
2460
2461impl InitOperationIr {
2462 fn nodes(&self) -> Vec<&TensorIr> {
2463 vec![&self.out]
2464 }
2465}
2466
2467impl TensorIr {
2468 fn mark_read_only(&mut self, nodes: &[TensorId], output: &mut Vec<TensorIr>) {
2469 if self.status == TensorStatus::ReadWrite && nodes.contains(&self.id) {
2470 output.push(self.clone());
2471 self.status = TensorStatus::ReadOnly;
2472 }
2473 }
2474}
2475
2476impl core::hash::Hash for RandomOpIr {
2477 fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
2478 self.out.hash(state);
2479
2480 match self.distribution {
2481 Distribution::Default => 1u8.hash(state),
2482 Distribution::Bernoulli(_) => 2u8.hash(state),
2483 Distribution::Uniform(_, _) => 3u8.hash(state),
2484 Distribution::Normal(_, _) => 4u8.hash(state),
2485 }
2486 }
2487}
2488
2489impl core::hash::Hash for ScalarOpIr {
2490 fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
2491 self.lhs.hash(state);
2492 self.out.hash(state);
2493 }
2494}
2495
2496impl core::hash::Hash for MaskFillOpIr {
2497 fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
2498 self.tensor.hash(state);
2499 self.mask.hash(state);
2500 self.out.hash(state);
2501 }
2502}
2503
2504impl core::hash::Hash for ClampOpIr {
2505 fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
2506 self.tensor.hash(state);
2507 self.out.hash(state);
2508 }
2509}
2510
2511impl core::hash::Hash for NumericOperationIr {
2512 fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
2513 match self {
2514 NumericOperationIr::Add(repr) => repr.hash(state),
2515 NumericOperationIr::AddScalar(repr) => repr.hash(state),
2516 NumericOperationIr::Sub(repr) => repr.hash(state),
2517 NumericOperationIr::SubScalar(repr) => repr.hash(state),
2518 NumericOperationIr::Div(repr) => repr.hash(state),
2519 NumericOperationIr::DivScalar(repr) => repr.hash(state),
2520 NumericOperationIr::Rem(repr) => repr.hash(state),
2521 NumericOperationIr::RemScalar(repr) => repr.hash(state),
2522 NumericOperationIr::Mul(repr) => repr.hash(state),
2523 NumericOperationIr::MulScalar(repr) => repr.hash(state),
2524 NumericOperationIr::Abs(repr) => repr.hash(state),
2525 NumericOperationIr::Ones(repr) => repr.hash(state),
2526 NumericOperationIr::Zeros(repr) => repr.hash(state),
2527 NumericOperationIr::Full(repr) => repr.0.hash(state),
2528 NumericOperationIr::Gather(repr) => repr.hash(state),
2529 NumericOperationIr::Scatter(repr) => repr.hash(state),
2530 NumericOperationIr::Select(repr) => repr.hash(state),
2531 NumericOperationIr::SelectAssign(repr) => repr.hash(state),
2532 NumericOperationIr::MaskWhere(repr) => repr.hash(state),
2533 NumericOperationIr::MaskFill(repr) => repr.hash(state),
2534 NumericOperationIr::MeanDim(repr) => repr.hash(state),
2535 NumericOperationIr::Mean(repr) => repr.hash(state),
2536 NumericOperationIr::Sum(repr) => repr.hash(state),
2537 NumericOperationIr::SumDim(repr) => repr.hash(state),
2538 NumericOperationIr::Prod(repr) => repr.hash(state),
2539 NumericOperationIr::ProdDim(repr) => repr.hash(state),
2540 NumericOperationIr::EqualElem(repr) => repr.hash(state),
2541 NumericOperationIr::Greater(repr) => repr.hash(state),
2542 NumericOperationIr::GreaterElem(repr) => repr.hash(state),
2543 NumericOperationIr::GreaterEqual(repr) => repr.hash(state),
2544 NumericOperationIr::GreaterEqualElem(repr) => repr.hash(state),
2545 NumericOperationIr::Lower(repr) => repr.hash(state),
2546 NumericOperationIr::LowerElem(repr) => repr.hash(state),
2547 NumericOperationIr::LowerEqual(repr) => repr.hash(state),
2548 NumericOperationIr::LowerEqualElem(repr) => repr.hash(state),
2549 NumericOperationIr::ArgMax(repr) => repr.hash(state),
2550 NumericOperationIr::ArgMin(repr) => repr.hash(state),
2551 NumericOperationIr::Max(repr) => repr.hash(state),
2552 NumericOperationIr::MaxDimWithIndices(repr) => repr.hash(state),
2553 NumericOperationIr::MinDimWithIndices(repr) => repr.hash(state),
2554 NumericOperationIr::Min(repr) => repr.hash(state),
2555 NumericOperationIr::MaxDim(repr) => repr.hash(state),
2556 NumericOperationIr::MinDim(repr) => repr.hash(state),
2557 NumericOperationIr::MaxAbs(repr) => repr.hash(state),
2558 NumericOperationIr::MaxAbsDim(repr) => repr.hash(state),
2559 NumericOperationIr::Clamp(repr) => repr.hash(state),
2560 NumericOperationIr::IntRandom(repr) => repr.hash(state),
2561 NumericOperationIr::Powf(repr) => repr.hash(state),
2562 }
2563 }
2564}