1use burn_backend::tensor::IndexingUpdateOp;
2use core::hash::Hash;
3use serde::{Deserialize, Serialize};
4
5use alloc::borrow::ToOwned;
6use alloc::boxed::Box;
7use alloc::{string::String, vec::Vec};
8
9use burn_backend::{
10 DType, Distribution, Slice,
11 ops::{
12 ConvOptions, ConvTransposeOptions, DeformConvOptions, GridSampleOptions,
13 GridSamplePaddingMode, InterpolateMode, InterpolateOptions,
14 },
15 quantization::QuantScheme,
16};
17
18use crate::{ScalarIr, TensorId, TensorIr, TensorStatus};
19
/// A custom operation, identified by name, with its tensor arguments.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct CustomOpIr {
    /// Unique identifier of the custom operation.
    pub id: String,
    /// Input tensor descriptions.
    pub inputs: Vec<TensorIr>,
    /// Output tensor descriptions.
    pub outputs: Vec<TensorIr>,
}
30
31impl CustomOpIr {
32 pub fn new(id: &'static str, inputs: &[TensorIr], outputs: &[TensorIr]) -> Self {
34 Self {
35 id: id.to_owned(),
36 inputs: inputs.to_vec(),
37 outputs: outputs.to_vec(),
38 }
39 }
40
41 pub fn as_fixed<const N_IN: usize, const N_OUT: usize>(
43 &self,
44 ) -> (&[TensorIr; N_IN], &[TensorIr; N_OUT]) {
45 (
46 self.inputs.as_slice().try_into().expect(
47 "Wrong number of inputs expected (expected {D}, is {}), check your implementation",
48 ),
49 self.outputs.as_slice().try_into().expect(
50 "Wrong number of outputs expected (expected {D}, is {}), check your implementation",
51 ),
52 )
53 }
54
55 fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
56 Box::new(self.inputs.iter())
57 }
58
59 fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
60 Box::new(self.outputs.iter())
61 }
62}
63
/// Top-level description of an operation, dispatched by tensor kind/dtype.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum OperationIr {
    /// Basic operation on a float tensor.
    BaseFloat(BaseOperationIr),
    /// Basic operation on an int tensor.
    BaseInt(BaseOperationIr),
    /// Basic operation on a bool tensor.
    BaseBool(BaseOperationIr),
    /// Numeric operation on a float tensor.
    NumericFloat(DType, NumericOperationIr),
    /// Numeric operation on an int tensor.
    NumericInt(DType, NumericOperationIr),
    /// Operation specific to bool tensors.
    Bool(BoolOperationIr),
    /// Operation specific to int tensors.
    Int(IntOperationIr),
    /// Operation specific to float tensors.
    Float(DType, FloatOperationIr),
    /// Module (neural-network) operation.
    Module(ModuleOperationIr),
    /// Tensor initialization from existing data.
    Init(InitOperationIr),
    /// User-defined operation.
    Custom(CustomOpIr),
    /// Explicitly drop a tensor.
    Drop(TensorIr),
}
92
/// Operations available only on float tensors.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum FloatOperationIr {
    Exp(UnaryOpIr),
    Log(UnaryOpIr),
    Log1p(UnaryOpIr),
    Erf(UnaryOpIr),
    PowfScalar(ScalarOpIr),
    Sqrt(UnaryOpIr),
    Cos(UnaryOpIr),
    Cosh(UnaryOpIr),
    Sin(UnaryOpIr),
    Sinh(UnaryOpIr),
    Tan(UnaryOpIr),
    Tanh(UnaryOpIr),
    ArcCos(UnaryOpIr),
    ArcCosh(UnaryOpIr),
    ArcSin(UnaryOpIr),
    ArcSinh(UnaryOpIr),
    ArcTan(UnaryOpIr),
    ArcTanh(UnaryOpIr),
    ArcTan2(BinaryOpIr),
    Round(UnaryOpIr),
    Floor(UnaryOpIr),
    Ceil(UnaryOpIr),
    Trunc(UnaryOpIr),
    /// Cast float to int.
    IntoInt(CastOpIr),
    Matmul(MatmulOpIr),
    Cross(CrossOpIr),
    /// Fill with samples from a random distribution.
    Random(RandomOpIr),
    /// Elementwise reciprocal (1/x).
    Recip(UnaryOpIr),
    IsNan(UnaryOpIr),
    IsInf(UnaryOpIr),
    /// Quantize with the given scheme and parameters.
    Quantize(QuantizeOpIr),
    /// Dequantize back to float.
    Dequantize(DequantizeOpIr),
    GridSample2d(GridSample2dOpIr),
}
163
/// Module (neural-network layer) operations: convolutions, pooling,
/// embeddings and interpolation, plus their backward passes.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum ModuleOperationIr {
    Embedding(EmbeddingOpIr),
    EmbeddingBackward(EmbeddingBackwardOpIr),
    Conv1d(Conv1dOpIr),
    Conv2d(Conv2dOpIr),
    Conv3d(Conv3dOpIr),
    // Boxed: the deformable-conv payloads are much larger than the other
    // variants, so boxing keeps the enum itself small.
    DeformableConv2d(Box<DeformConv2dOpIr>),
    DeformableConv2dBackward(Box<DeformConv2dBackwardOpIr>),
    ConvTranspose1d(ConvTranspose1dOpIr),
    ConvTranspose2d(ConvTranspose2dOpIr),
    ConvTranspose3d(ConvTranspose3dOpIr),
    AvgPool1d(AvgPool1dOpIr),
    AvgPool2d(AvgPool2dOpIr),
    AvgPool1dBackward(AvgPool1dBackwardOpIr),
    AvgPool2dBackward(AvgPool2dBackwardOpIr),
    AdaptiveAvgPool1d(AdaptiveAvgPool1dOpIr),
    AdaptiveAvgPool2d(AdaptiveAvgPool2dOpIr),
    AdaptiveAvgPool1dBackward(AdaptiveAvgPool1dBackwardOpIr),
    AdaptiveAvgPool2dBackward(AdaptiveAvgPool2dBackwardOpIr),
    MaxPool1d(MaxPool1dOpIr),
    MaxPool1dWithIndices(MaxPool1dWithIndicesOpIr),
    MaxPool1dWithIndicesBackward(MaxPool1dWithIndicesBackwardOpIr),
    MaxPool2d(MaxPool2dOpIr),
    MaxPool2dWithIndices(MaxPool2dWithIndicesOpIr),
    MaxPool2dWithIndicesBackward(MaxPool2dWithIndicesBackwardOpIr),
    Interpolate(InterpolateOpIr),
    InterpolateBackward(InterpolateBackwardOpIr),
}
232
/// Basic operations supported by every tensor kind (float, int, bool).
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum BaseOperationIr {
    Reshape(ShapeOpIr),

    SwapDims(SwapDimsOpIr),

    Permute(PermuteOpIr),

    Flip(FlipOpIr),

    Expand(ShapeOpIr),

    Unfold(UnfoldOpIr),

    Slice(SliceOpIr),
    SliceAssign(SliceAssignOpIr),
    Select(SelectOpIr),
    SelectAssign(SelectAssignOpIr),
    MaskWhere(MaskWhereOpIr),
    MaskFill(MaskFillOpIr),
    Gather(GatherOpIr),
    Scatter(ScatterOpIr),
    Equal(BinaryOpIr),
    EqualElem(ScalarOpIr),
    RepeatDim(RepeatDimOpIr),
    Cat(CatOpIr),
    Cast(CastOpIr),
    /// Allocate an uninitialized tensor.
    Empty(CreationOpIr),
    Ones(CreationOpIr),
    Zeros(CreationOpIr),
}
367
/// Numeric operations shared by float and int tensors.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum NumericOperationIr {
    Add(BinaryOpIr),
    AddScalar(ScalarOpIr),
    Sub(BinaryOpIr),
    SubScalar(ScalarOpIr),
    Div(BinaryOpIr),
    DivScalar(ScalarOpIr),
    Rem(BinaryOpIr),
    RemScalar(ScalarOpIr),
    Mul(BinaryOpIr),
    MulScalar(ScalarOpIr),
    Abs(UnaryOpIr),
    /// Fill a tensor with a single scalar value.
    Full(FullOpIr),
    MeanDim(ReduceDimOpIr),
    Mean(ReduceOpIr),
    Sum(ReduceOpIr),
    SumDim(ReduceDimOpIr),
    Prod(ReduceOpIr),
    ProdDim(ReduceDimOpIr),
    Greater(BinaryOpIr),
    GreaterElem(ScalarOpIr),
    GreaterEqual(BinaryOpIr),
    GreaterEqualElem(ScalarOpIr),
    Lower(BinaryOpIr),
    LowerElem(ScalarOpIr),
    LowerEqual(BinaryOpIr),
    LowerEqualElem(ScalarOpIr),
    ArgMax(ReduceDimOpIr),
    ArgMin(ReduceDimOpIr),
    Max(ReduceOpIr),
    MaxDimWithIndices(ReduceDimWithIndicesOpIr),
    MinDimWithIndices(ReduceDimWithIndicesOpIr),
    Min(ReduceOpIr),
    MaxDim(ReduceDimOpIr),
    MinDim(ReduceDimOpIr),
    /// Maximum of absolute values over the whole tensor.
    MaxAbs(ReduceOpIr),
    MaxAbsDim(ReduceDimOpIr),
    /// Clamp values between min and max scalars.
    Clamp(ClampOpIr),
    /// Fill an int tensor with samples from a random distribution.
    IntRandom(RandomOpIr),
    /// Elementwise power with a tensor exponent.
    Powf(BinaryOpIr),
    /// Cumulative reductions along an axis.
    CumSum(DimOpIr),
    CumProd(DimOpIr),
    CumMin(DimOpIr),
    CumMax(DimOpIr),
}
586
/// Operations available only on int tensors.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum IntOperationIr {
    /// Cast int to float.
    IntoFloat(CastOpIr),
    BitwiseAnd(BinaryOpIr),
    BitwiseAndScalar(ScalarOpIr),
    BitwiseOr(BinaryOpIr),
    BitwiseOrScalar(ScalarOpIr),
    BitwiseXor(BinaryOpIr),
    BitwiseXorScalar(ScalarOpIr),
    BitwiseNot(UnaryOpIr),
    BitwiseLeftShift(BinaryOpIr),
    BitwiseLeftShiftScalar(ScalarOpIr),
    BitwiseRightShift(BinaryOpIr),
    BitwiseRightShiftScalar(ScalarOpIr),
    Matmul(MatmulOpIr),
}
639
/// Operations available only on bool tensors.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum BoolOperationIr {
    /// Cast bool to float.
    IntoFloat(CastOpIr),
    /// Cast bool to int.
    IntoInt(CastOpIr),
    Not(UnaryOpIr),
    And(BinaryOpIr),
    Or(BinaryOpIr),
}
654
/// Swap two dimensions of a tensor.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct SwapDimsOpIr {
    /// Input tensor.
    pub input: TensorIr,
    /// Output tensor.
    pub out: TensorIr,
    /// First dimension to swap.
    pub dim1: usize,
    /// Second dimension to swap.
    pub dim2: usize,
}
667
/// Permute the dimensions of a tensor.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct PermuteOpIr {
    /// Input tensor.
    pub input: TensorIr,
    /// Output tensor.
    pub out: TensorIr,
    /// New ordering of the axes.
    pub axes: Vec<usize>,
}
678
/// Operation defined purely by input and output shapes (reshape, expand).
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct ShapeOpIr {
    /// Input tensor.
    pub input: TensorIr,
    /// Output tensor; carries the target shape.
    pub out: TensorIr,
}
687
/// Unfold windows of `size` every `step` elements along `dim`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct UnfoldOpIr {
    /// Input tensor.
    pub input: TensorIr,
    /// Output tensor.
    pub out: TensorIr,

    /// Dimension along which to unfold.
    pub dim: usize,
    /// Window size.
    pub size: usize,
    /// Stride between consecutive windows.
    pub step: usize,
}
703
/// Reverse the tensor along the given axes.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct FlipOpIr {
    /// Input tensor.
    pub input: TensorIr,
    /// Output tensor.
    pub out: TensorIr,
    /// Axes to flip.
    pub axes: Vec<usize>,
}
714
/// Fill the output tensor with samples from `distribution`.
// No `Hash` derive: `Distribution` carries float parameters.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct RandomOpIr {
    pub out: TensorIr,
    pub distribution: Distribution,
}
721
/// Creation operation with no input (empty, ones, zeros).
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct CreationOpIr {
    /// Output tensor to create.
    pub out: TensorIr,
}
729
/// Create a tensor filled with a single scalar value.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct FullOpIr {
    /// Output tensor.
    pub out: TensorIr,
    /// Fill value.
    pub value: ScalarIr,
}
738
/// Initialize a tensor from existing data.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct InitOperationIr {
    /// Output tensor to initialize.
    pub out: TensorIr,
}
747
/// Elementwise binary operation between two tensors.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct BinaryOpIr {
    pub lhs: TensorIr,
    pub rhs: TensorIr,
    pub out: TensorIr,
}
755
/// Matrix multiplication of two tensors.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MatmulOpIr {
    pub lhs: TensorIr,
    pub rhs: TensorIr,
    pub out: TensorIr,
}
763
/// Cross product of two tensors along `dim`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct CrossOpIr {
    pub lhs: TensorIr,
    pub rhs: TensorIr,
    pub out: TensorIr,
    pub dim: usize,
}
772
/// Elementwise unary operation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct UnaryOpIr {
    pub input: TensorIr,
    pub out: TensorIr,
}
779
/// Binary operation between a tensor (lhs) and a scalar (rhs).
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ScalarOpIr {
    pub lhs: TensorIr,
    pub rhs: ScalarIr,
    pub out: TensorIr,
}
789
/// Full reduction over all elements of a tensor.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Hash)]
#[allow(missing_docs)]
pub struct ReduceOpIr {
    pub input: TensorIr,
    pub out: TensorIr,
}
796
/// Reduction along a single axis.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Hash)]
#[allow(missing_docs)]
pub struct ReduceDimOpIr {
    pub input: TensorIr,
    pub out: TensorIr,
    pub axis: usize,
}
804
/// Cast a tensor to another dtype/kind (target encoded in `out`).
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct CastOpIr {
    pub input: TensorIr,
    pub out: TensorIr,
}
811
/// Operation parameterized by a single axis (e.g. cumulative ops).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Hash)]
#[allow(missing_docs)]
pub struct DimOpIr {
    pub input: TensorIr,
    pub out: TensorIr,
    pub axis: usize,
}
821
/// Gather elements along `dim` at the given indices.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct GatherOpIr {
    pub tensor: TensorIr,
    pub dim: usize,
    pub indices: TensorIr,
    pub out: TensorIr,
}
830
/// Scatter `value` into `tensor` along `dim` at `indices`, combining with
/// the given update rule (e.g. assign/add).
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ScatterOpIr {
    pub tensor: TensorIr,
    pub dim: usize,
    pub indices: TensorIr,
    pub value: TensorIr,
    pub update: IndexingUpdateOp,
    pub out: TensorIr,
}
841
/// Select slices of `tensor` along `dim` at the given indices.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct SelectOpIr {
    pub tensor: TensorIr,
    pub dim: usize,
    pub indices: TensorIr,
    pub out: TensorIr,
}
850
/// Assign `value` into slices of `tensor` along `dim` at `indices`,
/// combining with the given update rule.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct SelectAssignOpIr {
    pub tensor: TensorIr,
    pub dim: usize,
    pub indices: TensorIr,
    pub value: TensorIr,
    pub update: IndexingUpdateOp,
    pub out: TensorIr,
}
861
/// Slice a tensor with one range per dimension.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct SliceOpIr {
    pub tensor: TensorIr,
    pub ranges: Vec<Slice>,
    pub out: TensorIr,
}
869
870#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
871#[allow(missing_docs)]
872pub struct SliceAssignOpIr {
873 pub tensor: TensorIr,
874 pub ranges: Vec<burn_backend::Slice>,
875 pub value: TensorIr,
876 pub out: TensorIr,
877}
878
/// Replace elements of `tensor` where `mask` is true with elements of `value`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaskWhereOpIr {
    pub tensor: TensorIr,
    pub mask: TensorIr,
    pub value: TensorIr,
    pub out: TensorIr,
}
887
/// Replace elements of `tensor` where `mask` is true with the scalar `value`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaskFillOpIr {
    pub tensor: TensorIr,
    pub mask: TensorIr,
    pub value: ScalarIr,
    pub out: TensorIr,
}
896
/// Clamp tensor values into the `[min, max]` range.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ClampOpIr {
    pub tensor: TensorIr,
    pub min: ScalarIr,
    pub max: ScalarIr,
    pub out: TensorIr,
}
905
/// Repeat the tensor `times` times along `dim`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct RepeatDimOpIr {
    pub tensor: TensorIr,
    pub dim: usize,
    pub times: usize,
    pub out: TensorIr,
}
914
/// Concatenate tensors along `dim`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct CatOpIr {
    pub tensors: Vec<TensorIr>,
    pub dim: usize,
    pub out: TensorIr,
}
922
/// Axis reduction that also returns the winning indices (max/min with indices).
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ReduceDimWithIndicesOpIr {
    pub tensor: TensorIr,
    pub dim: usize,
    pub out: TensorIr,
    pub out_indices: TensorIr,
}
931
/// Embedding lookup: gather rows of `weights` at `indices`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct EmbeddingOpIr {
    pub weights: TensorIr,
    pub indices: TensorIr,
    pub out: TensorIr,
}
939
/// Backward pass of the embedding lookup.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct EmbeddingBackwardOpIr {
    pub weights: TensorIr,
    pub out_grad: TensorIr,
    pub indices: TensorIr,
    pub out: TensorIr,
}
948
/// 1D convolution.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv1dOpIr {
    pub x: TensorIr,
    pub weight: TensorIr,
    pub bias: Option<TensorIr>,
    pub options: Conv1dOptionsIr,
    pub out: TensorIr,
}
958
/// 2D convolution.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv2dOpIr {
    pub x: TensorIr,
    pub weight: TensorIr,
    pub bias: Option<TensorIr>,
    pub options: Conv2dOptionsIr,
    pub out: TensorIr,
}
968
/// Deformable 2D convolution with learned offsets and optional modulation mask.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct DeformConv2dOpIr {
    pub x: TensorIr,
    pub offset: TensorIr,
    pub weight: TensorIr,
    pub mask: Option<TensorIr>,
    pub bias: Option<TensorIr>,
    pub options: DeformableConv2dOptionsIr,
    pub out: TensorIr,
}
980
/// Backward pass of the deformable 2D convolution; produces one gradient per
/// (present) forward input.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct DeformConv2dBackwardOpIr {
    pub x: TensorIr,
    pub offset: TensorIr,
    pub weight: TensorIr,
    pub mask: Option<TensorIr>,
    pub bias: Option<TensorIr>,
    pub out_grad: TensorIr,
    pub options: DeformableConv2dOptionsIr,
    pub input_grad: TensorIr,
    pub offset_grad: TensorIr,
    pub weight_grad: TensorIr,
    pub mask_grad: Option<TensorIr>,
    pub bias_grad: Option<TensorIr>,
}
997
/// 3D convolution.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv3dOpIr {
    pub x: TensorIr,
    pub weight: TensorIr,
    pub bias: Option<TensorIr>,
    pub options: Conv3dOptionsIr,
    pub out: TensorIr,
}
1007
/// Transposed 1D convolution.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose1dOpIr {
    pub x: TensorIr,
    pub weight: TensorIr,
    pub bias: Option<TensorIr>,
    pub options: ConvTranspose1dOptionsIr,
    pub out: TensorIr,
}
1017
/// Transposed 2D convolution.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose2dOpIr {
    pub x: TensorIr,
    pub weight: TensorIr,
    pub bias: Option<TensorIr>,
    pub options: ConvTranspose2dOptionsIr,
    pub out: TensorIr,
}
1027
/// Transposed 3D convolution.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose3dOpIr {
    pub x: TensorIr,
    pub weight: TensorIr,
    pub bias: Option<TensorIr>,
    pub options: ConvTranspose3dOptionsIr,
    pub out: TensorIr,
}
1037
/// Serializable (hashable) mirror of [`ConvOptions<1>`].
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv1dOptionsIr {
    pub stride: [usize; 1],
    pub padding: [usize; 1],
    pub dilation: [usize; 1],
    pub groups: usize,
}
1046
/// Serializable (hashable) mirror of [`ConvOptions<2>`].
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv2dOptionsIr {
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub dilation: [usize; 2],
    pub groups: usize,
}
1055
/// Serializable (hashable) mirror of [`DeformConvOptions<2>`].
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct DeformableConv2dOptionsIr {
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub dilation: [usize; 2],
    pub weight_groups: usize,
    pub offset_groups: usize,
}
1065
/// Serializable (hashable) mirror of [`ConvOptions<3>`].
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv3dOptionsIr {
    pub stride: [usize; 3],
    pub padding: [usize; 3],
    pub dilation: [usize; 3],
    pub groups: usize,
}
1074
/// Serializable (hashable) mirror of [`ConvTransposeOptions<1>`].
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose1dOptionsIr {
    pub stride: [usize; 1],
    pub padding: [usize; 1],
    pub padding_out: [usize; 1],
    pub dilation: [usize; 1],
    pub groups: usize,
}
1084
/// Serializable (hashable) mirror of [`ConvTransposeOptions<2>`].
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose2dOptionsIr {
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub padding_out: [usize; 2],
    pub dilation: [usize; 2],
    pub groups: usize,
}
1094
/// Serializable (hashable) mirror of [`ConvTransposeOptions<3>`].
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose3dOptionsIr {
    pub stride: [usize; 3],
    pub padding: [usize; 3],
    pub padding_out: [usize; 3],
    pub dilation: [usize; 3],
    pub groups: usize,
}
1104
/// Quantization parameters.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct QuantizationParametersIr {
    /// The scaling factor(s).
    pub scales: TensorIr,
}
1111
/// Quantize a float tensor using the given scheme and parameters.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct QuantizeOpIr {
    pub tensor: TensorIr,
    pub qparams: QuantizationParametersIr,
    pub scheme: QuantScheme,
    pub out: TensorIr,
}
1120
/// Dequantize a quantized tensor back to float.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct DequantizeOpIr {
    pub input: TensorIr,
    pub out: TensorIr,
}
1127
1128impl From<ConvOptions<1>> for Conv1dOptionsIr {
1129 fn from(value: ConvOptions<1>) -> Self {
1130 Self {
1131 stride: value.stride,
1132 padding: value.padding,
1133 dilation: value.dilation,
1134 groups: value.groups,
1135 }
1136 }
1137}
1138
1139impl From<ConvOptions<2>> for Conv2dOptionsIr {
1140 fn from(value: ConvOptions<2>) -> Self {
1141 Self {
1142 stride: value.stride,
1143 padding: value.padding,
1144 dilation: value.dilation,
1145 groups: value.groups,
1146 }
1147 }
1148}
1149
1150impl From<ConvOptions<3>> for Conv3dOptionsIr {
1151 fn from(value: ConvOptions<3>) -> Self {
1152 Self {
1153 stride: value.stride,
1154 padding: value.padding,
1155 dilation: value.dilation,
1156 groups: value.groups,
1157 }
1158 }
1159}
1160
1161impl From<DeformConvOptions<2>> for DeformableConv2dOptionsIr {
1162 fn from(value: DeformConvOptions<2>) -> Self {
1163 Self {
1164 stride: value.stride,
1165 padding: value.padding,
1166 dilation: value.dilation,
1167 weight_groups: value.weight_groups,
1168 offset_groups: value.offset_groups,
1169 }
1170 }
1171}
1172
1173impl From<ConvTransposeOptions<1>> for ConvTranspose1dOptionsIr {
1174 fn from(value: ConvTransposeOptions<1>) -> Self {
1175 Self {
1176 stride: value.stride,
1177 padding: value.padding,
1178 padding_out: value.padding_out,
1179 dilation: value.dilation,
1180 groups: value.groups,
1181 }
1182 }
1183}
1184
1185impl From<ConvTransposeOptions<2>> for ConvTranspose2dOptionsIr {
1186 fn from(value: ConvTransposeOptions<2>) -> Self {
1187 Self {
1188 stride: value.stride,
1189 padding: value.padding,
1190 padding_out: value.padding_out,
1191 dilation: value.dilation,
1192 groups: value.groups,
1193 }
1194 }
1195}
1196
1197impl From<ConvTransposeOptions<3>> for ConvTranspose3dOptionsIr {
1198 fn from(value: ConvTransposeOptions<3>) -> Self {
1199 Self {
1200 stride: value.stride,
1201 padding: value.padding,
1202 padding_out: value.padding_out,
1203 dilation: value.dilation,
1204 groups: value.groups,
1205 }
1206 }
1207}
1208
1209impl From<Conv1dOptionsIr> for ConvOptions<1> {
1210 fn from(val: Conv1dOptionsIr) -> Self {
1211 ConvOptions {
1212 stride: val.stride,
1213 padding: val.padding,
1214 dilation: val.dilation,
1215 groups: val.groups,
1216 }
1217 }
1218}
1219
1220impl From<Conv2dOptionsIr> for ConvOptions<2> {
1221 fn from(val: Conv2dOptionsIr) -> Self {
1222 ConvOptions {
1223 stride: val.stride,
1224 padding: val.padding,
1225 dilation: val.dilation,
1226 groups: val.groups,
1227 }
1228 }
1229}
1230
1231impl From<Conv3dOptionsIr> for ConvOptions<3> {
1232 fn from(val: Conv3dOptionsIr) -> Self {
1233 ConvOptions {
1234 stride: val.stride,
1235 padding: val.padding,
1236 dilation: val.dilation,
1237 groups: val.groups,
1238 }
1239 }
1240}
1241
1242impl From<DeformableConv2dOptionsIr> for DeformConvOptions<2> {
1243 fn from(value: DeformableConv2dOptionsIr) -> Self {
1244 DeformConvOptions {
1245 stride: value.stride,
1246 padding: value.padding,
1247 dilation: value.dilation,
1248 weight_groups: value.weight_groups,
1249 offset_groups: value.offset_groups,
1250 }
1251 }
1252}
1253
1254impl From<ConvTranspose1dOptionsIr> for ConvTransposeOptions<1> {
1255 fn from(val: ConvTranspose1dOptionsIr) -> Self {
1256 ConvTransposeOptions {
1257 stride: val.stride,
1258 padding: val.padding,
1259 padding_out: val.padding_out,
1260 dilation: val.dilation,
1261 groups: val.groups,
1262 }
1263 }
1264}
1265
1266impl From<ConvTranspose2dOptionsIr> for ConvTransposeOptions<2> {
1267 fn from(val: ConvTranspose2dOptionsIr) -> Self {
1268 ConvTransposeOptions {
1269 stride: val.stride,
1270 padding: val.padding,
1271 padding_out: val.padding_out,
1272 dilation: val.dilation,
1273 groups: val.groups,
1274 }
1275 }
1276}
1277
1278impl From<ConvTranspose3dOptionsIr> for ConvTransposeOptions<3> {
1279 fn from(val: ConvTranspose3dOptionsIr) -> Self {
1280 ConvTransposeOptions {
1281 stride: val.stride,
1282 padding: val.padding,
1283 padding_out: val.padding_out,
1284 dilation: val.dilation,
1285 groups: val.groups,
1286 }
1287 }
1288}
1289
/// 1D average pooling.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AvgPool1dOpIr {
    pub x: TensorIr,
    pub kernel_size: usize,
    pub stride: usize,
    pub padding: usize,
    pub count_include_pad: bool,
    pub ceil_mode: bool,
    pub out: TensorIr,
}
1301
/// 2D average pooling.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AvgPool2dOpIr {
    pub x: TensorIr,
    pub kernel_size: [usize; 2],
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub count_include_pad: bool,
    pub ceil_mode: bool,
    pub out: TensorIr,
}
1313
/// Backward pass of 1D average pooling.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AvgPool1dBackwardOpIr {
    pub x: TensorIr,
    pub grad: TensorIr,
    pub kernel_size: usize,
    pub stride: usize,
    pub padding: usize,
    pub count_include_pad: bool,
    pub ceil_mode: bool,
    pub out: TensorIr,
}
1326
/// Backward pass of 2D average pooling.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AvgPool2dBackwardOpIr {
    pub x: TensorIr,
    pub grad: TensorIr,
    pub kernel_size: [usize; 2],
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub count_include_pad: bool,
    pub ceil_mode: bool,
    pub out: TensorIr,
}
1339
/// 1D adaptive average pooling to a fixed output size.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool1dOpIr {
    pub x: TensorIr,
    pub output_size: usize,
    pub out: TensorIr,
}
1347
/// 2D adaptive average pooling to a fixed output size.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool2dOpIr {
    pub x: TensorIr,
    pub output_size: [usize; 2],
    pub out: TensorIr,
}
1355
/// Backward pass of 1D adaptive average pooling.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool1dBackwardOpIr {
    pub x: TensorIr,
    pub grad: TensorIr,
    pub out: TensorIr,
}
1363
/// Backward pass of 2D adaptive average pooling.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool2dBackwardOpIr {
    pub x: TensorIr,
    pub grad: TensorIr,
    pub out: TensorIr,
}
1371
/// 1D max pooling.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool1dOpIr {
    pub x: TensorIr,
    pub kernel_size: usize,
    pub stride: usize,
    pub padding: usize,
    pub dilation: usize,
    pub ceil_mode: bool,
    pub out: TensorIr,
}
1383
/// 1D max pooling that also returns the argmax indices.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool1dWithIndicesOpIr {
    pub x: TensorIr,
    pub kernel_size: usize,
    pub stride: usize,
    pub padding: usize,
    pub dilation: usize,
    pub ceil_mode: bool,
    pub out: TensorIr,
    pub out_indices: TensorIr,
}
1396
/// Backward pass of 1D max pooling with indices.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool1dWithIndicesBackwardOpIr {
    pub x: TensorIr,
    pub grad: TensorIr,
    pub indices: TensorIr,
    pub kernel_size: usize,
    pub stride: usize,
    pub padding: usize,
    pub dilation: usize,
    pub ceil_mode: bool,
    pub out: TensorIr,
}
1410
/// 2D max pooling.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool2dOpIr {
    pub x: TensorIr,
    pub kernel_size: [usize; 2],
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub dilation: [usize; 2],
    pub ceil_mode: bool,
    pub out: TensorIr,
}
1422
/// 2D max pooling that also returns the argmax indices.
#[allow(missing_docs)]
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct MaxPool2dWithIndicesOpIr {
    pub x: TensorIr,
    pub kernel_size: [usize; 2],
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub dilation: [usize; 2],
    pub ceil_mode: bool,
    pub out: TensorIr,
    pub out_indices: TensorIr,
}
1435
/// Backward pass of 2D max pooling with indices.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool2dWithIndicesBackwardOpIr {
    pub x: TensorIr,
    pub grad: TensorIr,
    pub indices: TensorIr,
    pub kernel_size: [usize; 2],
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub dilation: [usize; 2],
    pub ceil_mode: bool,
    pub out: TensorIr,
}
1449
/// Serializable mirror of [`InterpolateMode`].
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub enum InterpolateModeIr {
    Nearest,
    Bilinear,
    Bicubic,
}
1457
/// Serializable mirror of [`InterpolateOptions`].
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct InterpolateOptionsIr {
    pub mode: InterpolateModeIr,
}
1463
/// 2D interpolation (resize) to a fixed output size.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct InterpolateOpIr {
    pub x: TensorIr,
    pub output_size: [usize; 2],
    pub options: InterpolateOptionsIr,
    pub out: TensorIr,
}
1472
impl From<InterpolateModeIr> for InterpolateMode {
    fn from(val: InterpolateModeIr) -> Self {
        // One-to-one mapping between the IR enum and the backend enum.
        match val {
            InterpolateModeIr::Nearest => Self::Nearest,
            InterpolateModeIr::Bilinear => Self::Bilinear,
            InterpolateModeIr::Bicubic => Self::Bicubic,
        }
    }
}
1482
impl From<InterpolateOptionsIr> for InterpolateOptions {
    fn from(val: InterpolateOptionsIr) -> Self {
        Self {
            // Delegates to the mode conversion above.
            mode: val.mode.into(),
        }
    }
}
1490
impl From<InterpolateMode> for InterpolateModeIr {
    fn from(val: InterpolateMode) -> Self {
        // Inverse of the IR -> backend mapping.
        match val {
            InterpolateMode::Nearest => Self::Nearest,
            InterpolateMode::Bilinear => Self::Bilinear,
            InterpolateMode::Bicubic => Self::Bicubic,
        }
    }
}
1500
impl From<InterpolateOptions> for InterpolateOptionsIr {
    fn from(val: InterpolateOptions) -> Self {
        Self {
            // Delegates to the mode conversion above.
            mode: val.mode.into(),
        }
    }
}
1508
/// Backward pass of 2D interpolation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct InterpolateBackwardOpIr {
    pub x: TensorIr,
    pub grad: TensorIr,
    pub output_size: [usize; 2],
    pub options: InterpolateOptionsIr,
    pub out: TensorIr,
}
1518
/// Serializable mirror of [`GridSamplePaddingMode`].
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub enum GridSamplePaddingModeIr {
    Zeros,
    Border,
    Reflection,
}
1526
/// Serializable mirror of [`GridSampleOptions`].
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct GridSampleOptionsIr {
    pub mode: InterpolateModeIr,
    pub padding_mode: GridSamplePaddingModeIr,
    pub align_corners: bool,
}
1534
/// 2D grid sampling of `tensor` at the coordinates given by `grid`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct GridSample2dOpIr {
    pub tensor: TensorIr,
    pub grid: TensorIr,
    pub options: GridSampleOptionsIr,
    pub out: TensorIr,
}
1543
impl From<GridSamplePaddingModeIr> for GridSamplePaddingMode {
    fn from(val: GridSamplePaddingModeIr) -> Self {
        // One-to-one mapping between the IR enum and the backend enum.
        match val {
            GridSamplePaddingModeIr::Zeros => Self::Zeros,
            GridSamplePaddingModeIr::Border => Self::Border,
            GridSamplePaddingModeIr::Reflection => Self::Reflection,
        }
    }
}
1553
impl From<GridSamplePaddingMode> for GridSamplePaddingModeIr {
    fn from(val: GridSamplePaddingMode) -> Self {
        // Inverse of the IR -> backend mapping.
        match val {
            GridSamplePaddingMode::Zeros => Self::Zeros,
            GridSamplePaddingMode::Border => Self::Border,
            GridSamplePaddingMode::Reflection => Self::Reflection,
        }
    }
}
1563
impl From<GridSampleOptionsIr> for GridSampleOptions {
    fn from(val: GridSampleOptionsIr) -> Self {
        Self {
            // Field-wise conversion; mode/padding delegate to the enum impls above.
            mode: val.mode.into(),
            padding_mode: val.padding_mode.into(),
            align_corners: val.align_corners,
        }
    }
}
1573
impl From<GridSampleOptions> for GridSampleOptionsIr {
    fn from(val: GridSampleOptions) -> Self {
        Self {
            // Field-wise conversion; mode/padding delegate to the enum impls above.
            mode: val.mode.into(),
            padding_mode: val.padding_mode.into(),
            align_corners: val.align_corners,
        }
    }
}
1583
impl OperationIr {
    /// Iterator over every tensor this operation reads.
    pub fn inputs(&self) -> impl Iterator<Item = &TensorIr> {
        match self {
            OperationIr::BaseFloat(repr) => repr.inputs(),
            OperationIr::BaseInt(repr) => repr.inputs(),
            OperationIr::BaseBool(repr) => repr.inputs(),
            OperationIr::NumericFloat(_dtype, repr) => repr.inputs(),
            OperationIr::NumericInt(_dtype, repr) => repr.inputs(),
            OperationIr::Bool(repr) => repr.inputs(),
            OperationIr::Int(repr) => repr.inputs(),
            OperationIr::Float(_dtype, repr) => repr.inputs(),
            OperationIr::Module(repr) => repr.inputs(),
            OperationIr::Init(repr) => repr.inputs(),
            OperationIr::Custom(repr) => repr.inputs(),
            // The dropped tensor counts as an input of the drop operation.
            OperationIr::Drop(repr) => Box::new([repr].into_iter()),
        }
    }

    /// Iterator over every tensor this operation produces.
    pub fn outputs(&self) -> impl Iterator<Item = &TensorIr> {
        match self {
            OperationIr::BaseFloat(repr) => repr.outputs(),
            OperationIr::BaseInt(repr) => repr.outputs(),
            OperationIr::BaseBool(repr) => repr.outputs(),
            OperationIr::NumericFloat(_dtype, repr) => repr.outputs(),
            OperationIr::NumericInt(_dtype, repr) => repr.outputs(),
            OperationIr::Bool(repr) => repr.outputs(),
            OperationIr::Int(repr) => repr.outputs(),
            OperationIr::Float(_dtype, repr) => repr.outputs(),
            OperationIr::Module(repr) => repr.outputs(),
            OperationIr::Init(repr) => repr.outputs(),
            OperationIr::Custom(repr) => repr.outputs(),
            // Dropping a tensor produces nothing.
            OperationIr::Drop(_repr) => Box::new([].into_iter()),
        }
    }

    /// All tensors involved in this operation (inputs followed by outputs).
    pub fn nodes(&self) -> Vec<&TensorIr> {
        self.inputs().chain(self.outputs()).collect()
    }

    /// Marks the given tensors as read-only within this operation, delegating
    /// to `TensorIr::mark_read_only`; returns the tensors whose status changed.
    pub fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
        match self {
            OperationIr::BaseFloat(repr) => repr.mark_read_only(nodes),
            OperationIr::BaseInt(repr) => repr.mark_read_only(nodes),
            OperationIr::BaseBool(repr) => repr.mark_read_only(nodes),
            OperationIr::NumericFloat(_dtype, repr) => repr.mark_read_only(nodes),
            OperationIr::NumericInt(_dtype, repr) => repr.mark_read_only(nodes),
            OperationIr::Bool(repr) => repr.mark_read_only(nodes),
            OperationIr::Int(repr) => repr.mark_read_only(nodes),
            OperationIr::Float(_dtype, repr) => repr.mark_read_only(nodes),
            OperationIr::Module(repr) => repr.mark_read_only(nodes),
            // Init has no inputs, so there is nothing to mark.
            OperationIr::Init(_) => Vec::new(),
            OperationIr::Drop(repr) => {
                let mut output = Vec::new();
                repr.mark_read_only(nodes, &mut output);
                output
            }
            OperationIr::Custom(repr) => {
                let mut output = Vec::new();

                // Only inputs are read; outputs are produced by the op.
                for input in repr.inputs.iter_mut() {
                    input.mark_read_only(nodes, &mut output);
                }

                output
            }
        }
    }
}
1659
impl BaseOperationIr {
    /// Tensors read by this base operation.
    ///
    /// Creation ops (`Empty`, `Ones`, `Zeros`) read nothing and yield an
    /// empty iterator; `Cat` reads its whole list of tensors.
    fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
        match self {
            BaseOperationIr::Reshape(repr) => Box::new([&repr.input].into_iter()),
            BaseOperationIr::SwapDims(repr) => Box::new([&repr.input].into_iter()),
            BaseOperationIr::Permute(repr) => Box::new([&repr.input].into_iter()),
            BaseOperationIr::Expand(repr) => Box::new([&repr.input].into_iter()),
            BaseOperationIr::Flip(repr) => Box::new([&repr.input].into_iter()),
            BaseOperationIr::Slice(repr) => Box::new([&repr.tensor].into_iter()),
            BaseOperationIr::SliceAssign(repr) => Box::new([&repr.tensor, &repr.value].into_iter()),
            BaseOperationIr::Gather(repr) => Box::new([&repr.tensor, &repr.indices].into_iter()),
            BaseOperationIr::Scatter(repr) => {
                Box::new([&repr.tensor, &repr.indices, &repr.value].into_iter())
            }
            BaseOperationIr::Select(repr) => Box::new([&repr.tensor, &repr.indices].into_iter()),
            BaseOperationIr::SelectAssign(repr) => {
                Box::new([&repr.tensor, &repr.indices, &repr.value].into_iter())
            }
            BaseOperationIr::MaskWhere(repr) => {
                Box::new([&repr.tensor, &repr.mask, &repr.value].into_iter())
            }
            // MaskFill's fill value is a scalar, so only tensor + mask are inputs.
            BaseOperationIr::MaskFill(repr) => Box::new([&repr.tensor, &repr.mask].into_iter()),
            BaseOperationIr::Equal(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
            BaseOperationIr::EqualElem(repr) => Box::new([&repr.lhs].into_iter()),
            BaseOperationIr::RepeatDim(repr) => Box::new([&repr.tensor].into_iter()),
            BaseOperationIr::Cat(repr) => Box::new(repr.tensors.iter()),
            BaseOperationIr::Cast(repr) => Box::new([&repr.input].into_iter()),
            BaseOperationIr::Unfold(repr) => Box::new([&repr.input].into_iter()),
            BaseOperationIr::Empty(_repr) => Box::new([].into_iter()),
            BaseOperationIr::Ones(_repr) => Box::new([].into_iter()),
            BaseOperationIr::Zeros(_repr) => Box::new([].into_iter()),
        }
    }

    /// Tensors produced by this base operation — exactly one `out` per variant.
    fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
        match self {
            BaseOperationIr::Reshape(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::SwapDims(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::Permute(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::Expand(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::Flip(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::Slice(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::SliceAssign(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::Gather(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::Scatter(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::Select(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::SelectAssign(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::MaskWhere(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::MaskFill(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::Equal(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::EqualElem(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::RepeatDim(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::Cat(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::Cast(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::Unfold(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::Empty(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::Ones(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::Zeros(repr) => Box::new([&repr.out].into_iter()),
        }
    }

    /// Marks the input tensors listed in `nodes` as read-only and returns the
    /// descriptors collected by the per-tensor `mark_read_only` calls.
    ///
    /// Marks exactly the tensors that `inputs()` reports; creation ops have
    /// nothing to mark.
    fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
        let mut output = Vec::new();

        match self {
            BaseOperationIr::Reshape(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::SwapDims(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::Permute(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }

            BaseOperationIr::Expand(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }

            BaseOperationIr::Flip(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::Slice(repr) => {
                repr.tensor.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::SliceAssign(repr) => {
                repr.tensor.mark_read_only(nodes, &mut output);
                repr.value.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::Gather(repr) => {
                repr.tensor.mark_read_only(nodes, &mut output);
                repr.indices.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::Scatter(repr) => {
                repr.tensor.mark_read_only(nodes, &mut output);
                repr.indices.mark_read_only(nodes, &mut output);
                repr.value.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::Select(repr) => {
                repr.tensor.mark_read_only(nodes, &mut output);
                repr.indices.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::SelectAssign(repr) => {
                repr.tensor.mark_read_only(nodes, &mut output);
                repr.indices.mark_read_only(nodes, &mut output);
                repr.value.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::MaskWhere(repr) => {
                repr.tensor.mark_read_only(nodes, &mut output);
                repr.mask.mark_read_only(nodes, &mut output);
                repr.value.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::MaskFill(repr) => {
                repr.tensor.mark_read_only(nodes, &mut output);
                repr.mask.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::Equal(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::EqualElem(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::RepeatDim(repr) => {
                repr.tensor.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::Cat(repr) => {
                for t in repr.tensors.iter_mut() {
                    t.mark_read_only(nodes, &mut output);
                }
            }
            BaseOperationIr::Cast(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::Unfold(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::Empty(_) => {}
            BaseOperationIr::Zeros(_) => {}
            BaseOperationIr::Ones(_) => {}
        };

        output
    }
}
1805
1806impl NumericOperationIr {
1807 fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
1808 match self {
1809 NumericOperationIr::Add(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
1810 NumericOperationIr::AddScalar(repr) => Box::new([&repr.lhs].into_iter()),
1811 NumericOperationIr::Sub(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
1812 NumericOperationIr::SubScalar(repr) => Box::new([&repr.lhs].into_iter()),
1813 NumericOperationIr::Mul(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
1814 NumericOperationIr::MulScalar(repr) => Box::new([&repr.lhs].into_iter()),
1815 NumericOperationIr::Div(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
1816 NumericOperationIr::DivScalar(repr) => Box::new([&repr.lhs].into_iter()),
1817 NumericOperationIr::Rem(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
1818 NumericOperationIr::RemScalar(repr) => Box::new([&repr.lhs].into_iter()),
1819 NumericOperationIr::GreaterElem(repr) => Box::new([&repr.lhs].into_iter()),
1820 NumericOperationIr::GreaterEqualElem(repr) => Box::new([&repr.lhs].into_iter()),
1821 NumericOperationIr::LowerElem(repr) => Box::new([&repr.lhs].into_iter()),
1822 NumericOperationIr::LowerEqualElem(repr) => Box::new([&repr.lhs].into_iter()),
1823 NumericOperationIr::Greater(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
1824 NumericOperationIr::GreaterEqual(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
1825 NumericOperationIr::Lower(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
1826 NumericOperationIr::LowerEqual(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
1827 NumericOperationIr::ArgMax(repr) => Box::new([&repr.input].into_iter()),
1828 NumericOperationIr::ArgMin(repr) => Box::new([&repr.input].into_iter()),
1829 NumericOperationIr::Clamp(repr) => Box::new([&repr.tensor].into_iter()),
1830 NumericOperationIr::Abs(repr) => Box::new([&repr.input].into_iter()),
1831 NumericOperationIr::Full(_repr) => Box::new([].into_iter()),
1832 NumericOperationIr::MeanDim(repr) => Box::new([&repr.input].into_iter()),
1833 NumericOperationIr::Mean(repr) => Box::new([&repr.input].into_iter()),
1834 NumericOperationIr::Sum(repr) => Box::new([&repr.input].into_iter()),
1835 NumericOperationIr::SumDim(repr) => Box::new([&repr.input].into_iter()),
1836 NumericOperationIr::Prod(repr) => Box::new([&repr.input].into_iter()),
1837 NumericOperationIr::ProdDim(repr) => Box::new([&repr.input].into_iter()),
1838 NumericOperationIr::Max(repr) => Box::new([&repr.input].into_iter()),
1839 NumericOperationIr::MaxDimWithIndices(repr) => Box::new([&repr.tensor].into_iter()),
1840 NumericOperationIr::MinDimWithIndices(repr) => Box::new([&repr.tensor].into_iter()),
1841 NumericOperationIr::Min(repr) => Box::new([&repr.input].into_iter()),
1842 NumericOperationIr::MaxDim(repr) => Box::new([&repr.input].into_iter()),
1843 NumericOperationIr::MinDim(repr) => Box::new([&repr.input].into_iter()),
1844 NumericOperationIr::MaxAbs(repr) => Box::new([&repr.input].into_iter()),
1845 NumericOperationIr::MaxAbsDim(repr) => Box::new([&repr.input].into_iter()),
1846 NumericOperationIr::IntRandom(_repr) => Box::new([].into_iter()),
1847 NumericOperationIr::Powf(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
1848 NumericOperationIr::CumMin(repr) => Box::new([&repr.out].into_iter()),
1849 NumericOperationIr::CumMax(repr) => Box::new([&repr.out].into_iter()),
1850 NumericOperationIr::CumProd(repr) => Box::new([&repr.out].into_iter()),
1851 NumericOperationIr::CumSum(repr) => Box::new([&repr.out].into_iter()),
1852 }
1853 }
1854
1855 fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
1856 match self {
1857 NumericOperationIr::Add(repr) => Box::new([&repr.out].into_iter()),
1858 NumericOperationIr::AddScalar(repr) => Box::new([&repr.out].into_iter()),
1859 NumericOperationIr::Sub(repr) => Box::new([&repr.out].into_iter()),
1860 NumericOperationIr::SubScalar(repr) => Box::new([&repr.out].into_iter()),
1861 NumericOperationIr::Mul(repr) => Box::new([&repr.out].into_iter()),
1862 NumericOperationIr::MulScalar(repr) => Box::new([&repr.out].into_iter()),
1863 NumericOperationIr::Div(repr) => Box::new([&repr.out].into_iter()),
1864 NumericOperationIr::DivScalar(repr) => Box::new([&repr.out].into_iter()),
1865 NumericOperationIr::Rem(repr) => Box::new([&repr.out].into_iter()),
1866 NumericOperationIr::RemScalar(repr) => Box::new([&repr.out].into_iter()),
1867 NumericOperationIr::GreaterElem(repr) => Box::new([&repr.out].into_iter()),
1868 NumericOperationIr::GreaterEqualElem(repr) => Box::new([&repr.out].into_iter()),
1869 NumericOperationIr::LowerElem(repr) => Box::new([&repr.out].into_iter()),
1870 NumericOperationIr::LowerEqualElem(repr) => Box::new([&repr.out].into_iter()),
1871 NumericOperationIr::Greater(repr) => Box::new([&repr.out].into_iter()),
1872 NumericOperationIr::GreaterEqual(repr) => Box::new([&repr.out].into_iter()),
1873 NumericOperationIr::Lower(repr) => Box::new([&repr.out].into_iter()),
1874 NumericOperationIr::LowerEqual(repr) => Box::new([&repr.out].into_iter()),
1875 NumericOperationIr::ArgMax(repr) => Box::new([&repr.out].into_iter()),
1876 NumericOperationIr::ArgMin(repr) => Box::new([&repr.out].into_iter()),
1877 NumericOperationIr::Clamp(repr) => Box::new([&repr.out].into_iter()),
1878 NumericOperationIr::Abs(repr) => Box::new([&repr.out].into_iter()),
1879 NumericOperationIr::Full(repr) => Box::new([&repr.out].into_iter()),
1880 NumericOperationIr::MeanDim(repr) => Box::new([&repr.out].into_iter()),
1881 NumericOperationIr::Mean(repr) => Box::new([&repr.out].into_iter()),
1882 NumericOperationIr::Sum(repr) => Box::new([&repr.out].into_iter()),
1883 NumericOperationIr::SumDim(repr) => Box::new([&repr.out].into_iter()),
1884 NumericOperationIr::Prod(repr) => Box::new([&repr.out].into_iter()),
1885 NumericOperationIr::ProdDim(repr) => Box::new([&repr.out].into_iter()),
1886 NumericOperationIr::Max(repr) => Box::new([&repr.out].into_iter()),
1887 NumericOperationIr::MaxDimWithIndices(repr) => {
1888 Box::new([&repr.out, &repr.out_indices].into_iter())
1889 }
1890 NumericOperationIr::MinDimWithIndices(repr) => {
1891 Box::new([&repr.out, &repr.out_indices].into_iter())
1892 }
1893 NumericOperationIr::Min(repr) => Box::new([&repr.out].into_iter()),
1894 NumericOperationIr::MaxDim(repr) => Box::new([&repr.out].into_iter()),
1895 NumericOperationIr::MinDim(repr) => Box::new([&repr.out].into_iter()),
1896 NumericOperationIr::MaxAbs(repr) => Box::new([&repr.out].into_iter()),
1897 NumericOperationIr::MaxAbsDim(repr) => Box::new([&repr.out].into_iter()),
1898 NumericOperationIr::IntRandom(repr) => Box::new([&repr.out].into_iter()),
1899 NumericOperationIr::Powf(repr) => Box::new([&repr.out].into_iter()),
1900 NumericOperationIr::CumMin(repr) => Box::new([&repr.out].into_iter()),
1901 NumericOperationIr::CumMax(repr) => Box::new([&repr.out].into_iter()),
1902 NumericOperationIr::CumProd(repr) => Box::new([&repr.out].into_iter()),
1903 NumericOperationIr::CumSum(repr) => Box::new([&repr.out].into_iter()),
1904 }
1905 }
1906 fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
1907 let mut output = Vec::new();
1908
1909 match self {
1910 NumericOperationIr::Add(repr) => {
1911 repr.lhs.mark_read_only(nodes, &mut output);
1912 repr.rhs.mark_read_only(nodes, &mut output);
1913 }
1914 NumericOperationIr::AddScalar(repr) => {
1915 repr.lhs.mark_read_only(nodes, &mut output);
1916 }
1917 NumericOperationIr::Sub(repr) => {
1918 repr.lhs.mark_read_only(nodes, &mut output);
1919 repr.rhs.mark_read_only(nodes, &mut output);
1920 }
1921 NumericOperationIr::SubScalar(repr) => {
1922 repr.lhs.mark_read_only(nodes, &mut output);
1923 }
1924 NumericOperationIr::Mul(repr) => {
1925 repr.lhs.mark_read_only(nodes, &mut output);
1926 repr.rhs.mark_read_only(nodes, &mut output);
1927 }
1928 NumericOperationIr::MulScalar(repr) => {
1929 repr.lhs.mark_read_only(nodes, &mut output);
1930 }
1931 NumericOperationIr::Div(repr) => {
1932 repr.lhs.mark_read_only(nodes, &mut output);
1933 repr.rhs.mark_read_only(nodes, &mut output);
1934 }
1935 NumericOperationIr::DivScalar(repr) => {
1936 repr.lhs.mark_read_only(nodes, &mut output);
1937 }
1938 NumericOperationIr::Rem(repr) => {
1939 repr.lhs.mark_read_only(nodes, &mut output);
1940 repr.rhs.mark_read_only(nodes, &mut output);
1941 }
1942 NumericOperationIr::RemScalar(repr) => {
1943 repr.lhs.mark_read_only(nodes, &mut output);
1944 }
1945 NumericOperationIr::GreaterElem(repr) => {
1946 repr.lhs.mark_read_only(nodes, &mut output);
1947 }
1948 NumericOperationIr::GreaterEqualElem(repr) => {
1949 repr.lhs.mark_read_only(nodes, &mut output);
1950 }
1951 NumericOperationIr::LowerElem(repr) => {
1952 repr.lhs.mark_read_only(nodes, &mut output);
1953 }
1954 NumericOperationIr::LowerEqualElem(repr) => {
1955 repr.lhs.mark_read_only(nodes, &mut output);
1956 }
1957 NumericOperationIr::Greater(repr) => {
1958 repr.lhs.mark_read_only(nodes, &mut output);
1959 repr.rhs.mark_read_only(nodes, &mut output);
1960 }
1961 NumericOperationIr::GreaterEqual(repr) => {
1962 repr.lhs.mark_read_only(nodes, &mut output);
1963 repr.rhs.mark_read_only(nodes, &mut output);
1964 }
1965 NumericOperationIr::Lower(repr) => {
1966 repr.lhs.mark_read_only(nodes, &mut output);
1967 repr.rhs.mark_read_only(nodes, &mut output);
1968 }
1969 NumericOperationIr::LowerEqual(repr) => {
1970 repr.lhs.mark_read_only(nodes, &mut output);
1971 repr.rhs.mark_read_only(nodes, &mut output);
1972 }
1973 NumericOperationIr::ArgMax(repr) => {
1974 repr.input.mark_read_only(nodes, &mut output);
1975 }
1976 NumericOperationIr::ArgMin(repr) => {
1977 repr.input.mark_read_only(nodes, &mut output);
1978 }
1979 NumericOperationIr::Clamp(repr) => {
1980 repr.tensor.mark_read_only(nodes, &mut output);
1981 }
1982 NumericOperationIr::Abs(repr) => {
1983 repr.input.mark_read_only(nodes, &mut output);
1984 }
1985 NumericOperationIr::Full(_) => {}
1986 NumericOperationIr::MeanDim(repr) => {
1987 repr.input.mark_read_only(nodes, &mut output);
1988 }
1989 NumericOperationIr::Mean(repr) => {
1990 repr.input.mark_read_only(nodes, &mut output);
1991 }
1992 NumericOperationIr::Sum(repr) => {
1993 repr.input.mark_read_only(nodes, &mut output);
1994 }
1995 NumericOperationIr::SumDim(repr) => {
1996 repr.input.mark_read_only(nodes, &mut output);
1997 }
1998 NumericOperationIr::Prod(repr) => {
1999 repr.input.mark_read_only(nodes, &mut output);
2000 }
2001 NumericOperationIr::ProdDim(repr) => {
2002 repr.input.mark_read_only(nodes, &mut output);
2003 }
2004 NumericOperationIr::Max(repr) => {
2005 repr.input.mark_read_only(nodes, &mut output);
2006 }
2007 NumericOperationIr::MaxDimWithIndices(repr) => {
2008 repr.tensor.mark_read_only(nodes, &mut output);
2009 }
2010 NumericOperationIr::MinDimWithIndices(repr) => {
2011 repr.tensor.mark_read_only(nodes, &mut output);
2012 }
2013 NumericOperationIr::Min(repr) => {
2014 repr.input.mark_read_only(nodes, &mut output);
2015 }
2016 NumericOperationIr::MaxDim(repr) => {
2017 repr.input.mark_read_only(nodes, &mut output);
2018 }
2019 NumericOperationIr::MinDim(repr) => {
2020 repr.input.mark_read_only(nodes, &mut output);
2021 }
2022 NumericOperationIr::MaxAbs(repr) => {
2023 repr.input.mark_read_only(nodes, &mut output);
2024 }
2025 NumericOperationIr::MaxAbsDim(repr) => {
2026 repr.input.mark_read_only(nodes, &mut output);
2027 }
2028 NumericOperationIr::IntRandom(_) => {}
2029 NumericOperationIr::Powf(repr) => {
2030 repr.lhs.mark_read_only(nodes, &mut output);
2031 repr.rhs.mark_read_only(nodes, &mut output);
2032 }
2033 NumericOperationIr::CumSum(repr) => {
2034 repr.input.mark_read_only(nodes, &mut output);
2035 }
2036 NumericOperationIr::CumProd(repr) => {
2037 repr.input.mark_read_only(nodes, &mut output);
2038 }
2039 NumericOperationIr::CumMin(repr) => {
2040 repr.input.mark_read_only(nodes, &mut output);
2041 }
2042 NumericOperationIr::CumMax(repr) => {
2043 repr.input.mark_read_only(nodes, &mut output);
2044 }
2045 };
2046
2047 output
2048 }
2049}
2050
impl FloatOperationIr {
    /// Tensors read by this float operation.
    ///
    /// `Random` reads nothing; `Quantize` reads both the tensor and its
    /// quantization scales; `GridSample2d` reads the tensor and the grid.
    fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
        match self {
            FloatOperationIr::Matmul(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
            FloatOperationIr::Cross(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
            FloatOperationIr::Random(_repr) => Box::new([].into_iter()),
            FloatOperationIr::Exp(repr) => Box::new([&repr.input].into_iter()),
            FloatOperationIr::Log(repr) => Box::new([&repr.input].into_iter()),
            FloatOperationIr::Log1p(repr) => Box::new([&repr.input].into_iter()),
            FloatOperationIr::Erf(repr) => Box::new([&repr.input].into_iter()),
            FloatOperationIr::Recip(repr) => Box::new([&repr.input].into_iter()),
            FloatOperationIr::PowfScalar(repr) => Box::new([&repr.lhs].into_iter()),
            FloatOperationIr::Sqrt(repr) => Box::new([&repr.input].into_iter()),
            FloatOperationIr::Cos(repr) => Box::new([&repr.input].into_iter()),
            FloatOperationIr::Sin(repr) => Box::new([&repr.input].into_iter()),
            FloatOperationIr::Tanh(repr) => Box::new([&repr.input].into_iter()),
            FloatOperationIr::Round(repr) => Box::new([&repr.input].into_iter()),
            FloatOperationIr::Floor(repr) => Box::new([&repr.input].into_iter()),
            FloatOperationIr::Ceil(repr) => Box::new([&repr.input].into_iter()),
            FloatOperationIr::Trunc(repr) => Box::new([&repr.input].into_iter()),
            FloatOperationIr::IntoInt(repr) => Box::new([&repr.input].into_iter()),
            FloatOperationIr::Quantize(repr) => {
                Box::new([&repr.tensor, &repr.qparams.scales].into_iter())
            }
            FloatOperationIr::Dequantize(repr) => Box::new([&repr.input].into_iter()),
            FloatOperationIr::IsNan(repr) => Box::new([&repr.input].into_iter()),
            FloatOperationIr::IsInf(repr) => Box::new([&repr.input].into_iter()),
            FloatOperationIr::GridSample2d(repr) => {
                Box::new([&repr.tensor, &repr.grid].into_iter())
            }
            FloatOperationIr::Tan(repr) => Box::new([&repr.input].into_iter()),
            FloatOperationIr::Cosh(repr) => Box::new([&repr.input].into_iter()),
            FloatOperationIr::Sinh(repr) => Box::new([&repr.input].into_iter()),
            FloatOperationIr::ArcCos(repr) => Box::new([&repr.input].into_iter()),
            FloatOperationIr::ArcCosh(repr) => Box::new([&repr.input].into_iter()),
            FloatOperationIr::ArcSin(repr) => Box::new([&repr.input].into_iter()),
            FloatOperationIr::ArcSinh(repr) => Box::new([&repr.input].into_iter()),
            FloatOperationIr::ArcTan(repr) => Box::new([&repr.input].into_iter()),
            FloatOperationIr::ArcTanh(repr) => Box::new([&repr.input].into_iter()),
            FloatOperationIr::ArcTan2(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
        }
    }
    /// Tensors produced by this float operation — exactly one `out` per variant.
    fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
        match self {
            FloatOperationIr::Matmul(repr) => Box::new([&repr.out].into_iter()),
            FloatOperationIr::Cross(repr) => Box::new([&repr.out].into_iter()),
            FloatOperationIr::Random(repr) => Box::new([&repr.out].into_iter()),
            FloatOperationIr::Exp(repr) => Box::new([&repr.out].into_iter()),
            FloatOperationIr::Log(repr) => Box::new([&repr.out].into_iter()),
            FloatOperationIr::Log1p(repr) => Box::new([&repr.out].into_iter()),
            FloatOperationIr::Erf(repr) => Box::new([&repr.out].into_iter()),
            FloatOperationIr::Recip(repr) => Box::new([&repr.out].into_iter()),
            FloatOperationIr::PowfScalar(repr) => Box::new([&repr.out].into_iter()),
            FloatOperationIr::Sqrt(repr) => Box::new([&repr.out].into_iter()),
            FloatOperationIr::Cos(repr) => Box::new([&repr.out].into_iter()),
            FloatOperationIr::Sin(repr) => Box::new([&repr.out].into_iter()),
            FloatOperationIr::Tanh(repr) => Box::new([&repr.out].into_iter()),
            FloatOperationIr::Round(repr) => Box::new([&repr.out].into_iter()),
            FloatOperationIr::Floor(repr) => Box::new([&repr.out].into_iter()),
            FloatOperationIr::Ceil(repr) => Box::new([&repr.out].into_iter()),
            FloatOperationIr::Trunc(repr) => Box::new([&repr.out].into_iter()),
            FloatOperationIr::IntoInt(repr) => Box::new([&repr.out].into_iter()),
            FloatOperationIr::Quantize(repr) => Box::new([&repr.out].into_iter()),
            FloatOperationIr::Dequantize(repr) => Box::new([&repr.out].into_iter()),
            FloatOperationIr::IsNan(repr) => Box::new([&repr.out].into_iter()),
            FloatOperationIr::IsInf(repr) => Box::new([&repr.out].into_iter()),
            FloatOperationIr::GridSample2d(repr) => Box::new([&repr.out].into_iter()),
            FloatOperationIr::Tan(repr) => Box::new([&repr.out].into_iter()),
            FloatOperationIr::Cosh(repr) => Box::new([&repr.out].into_iter()),
            FloatOperationIr::Sinh(repr) => Box::new([&repr.out].into_iter()),
            FloatOperationIr::ArcCos(repr) => Box::new([&repr.out].into_iter()),
            FloatOperationIr::ArcCosh(repr) => Box::new([&repr.out].into_iter()),
            FloatOperationIr::ArcSin(repr) => Box::new([&repr.out].into_iter()),
            FloatOperationIr::ArcSinh(repr) => Box::new([&repr.out].into_iter()),
            FloatOperationIr::ArcTan(repr) => Box::new([&repr.out].into_iter()),
            FloatOperationIr::ArcTanh(repr) => Box::new([&repr.out].into_iter()),
            FloatOperationIr::ArcTan2(repr) => Box::new([&repr.out].into_iter()),
        }
    }

    /// Marks the input tensors listed in `nodes` as read-only and returns the
    /// descriptors collected by the per-tensor `mark_read_only` calls.
    ///
    /// Marks exactly the tensors that `inputs()` reports; `Random` has
    /// nothing to mark.
    fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
        let mut output = Vec::new();

        match self {
            FloatOperationIr::Matmul(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            FloatOperationIr::Cross(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            FloatOperationIr::Random(_) => {}
            FloatOperationIr::Exp(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            FloatOperationIr::Log(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            FloatOperationIr::Log1p(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            FloatOperationIr::Erf(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            FloatOperationIr::Recip(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            FloatOperationIr::PowfScalar(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
            }
            FloatOperationIr::Sqrt(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            FloatOperationIr::Cos(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            FloatOperationIr::Sin(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            FloatOperationIr::Tanh(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            FloatOperationIr::Round(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            FloatOperationIr::Floor(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            FloatOperationIr::Ceil(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            FloatOperationIr::Trunc(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            FloatOperationIr::Quantize(repr) => {
                repr.tensor.mark_read_only(nodes, &mut output);
                repr.qparams.scales.mark_read_only(nodes, &mut output);
            }
            FloatOperationIr::Dequantize(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            FloatOperationIr::IntoInt(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            FloatOperationIr::IsNan(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            FloatOperationIr::IsInf(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            FloatOperationIr::GridSample2d(repr) => {
                repr.tensor.mark_read_only(nodes, &mut output);
                repr.grid.mark_read_only(nodes, &mut output);
            }
            FloatOperationIr::Tan(repr) => repr.input.mark_read_only(nodes, &mut output),
            FloatOperationIr::Cosh(repr) => repr.input.mark_read_only(nodes, &mut output),
            FloatOperationIr::Sinh(repr) => repr.input.mark_read_only(nodes, &mut output),
            FloatOperationIr::ArcCos(repr) => repr.input.mark_read_only(nodes, &mut output),
            FloatOperationIr::ArcCosh(repr) => repr.input.mark_read_only(nodes, &mut output),
            FloatOperationIr::ArcSin(repr) => repr.input.mark_read_only(nodes, &mut output),
            FloatOperationIr::ArcSinh(repr) => repr.input.mark_read_only(nodes, &mut output),
            FloatOperationIr::ArcTan(repr) => repr.input.mark_read_only(nodes, &mut output),
            FloatOperationIr::ArcTanh(repr) => repr.input.mark_read_only(nodes, &mut output),
            FloatOperationIr::ArcTan2(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
        };

        output
    }
}
2224
impl IntOperationIr {
    /// Tensors read by this integer operation; scalar bitwise variants read
    /// only their tensor operand.
    fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
        match self {
            IntOperationIr::Matmul(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
            IntOperationIr::IntoFloat(repr) => Box::new([&repr.input].into_iter()),
            IntOperationIr::BitwiseAnd(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
            IntOperationIr::BitwiseAndScalar(repr) => Box::new([&repr.lhs].into_iter()),
            IntOperationIr::BitwiseOr(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
            IntOperationIr::BitwiseOrScalar(repr) => Box::new([&repr.lhs].into_iter()),
            IntOperationIr::BitwiseXor(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
            IntOperationIr::BitwiseXorScalar(repr) => Box::new([&repr.lhs].into_iter()),
            IntOperationIr::BitwiseNot(repr) => Box::new([&repr.input].into_iter()),
            IntOperationIr::BitwiseLeftShift(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
            IntOperationIr::BitwiseLeftShiftScalar(repr) => Box::new([&repr.lhs].into_iter()),
            IntOperationIr::BitwiseRightShift(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
            IntOperationIr::BitwiseRightShiftScalar(repr) => Box::new([&repr.lhs].into_iter()),
        }
    }

    /// Tensors produced by this integer operation — exactly one `out` per variant.
    fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
        match self {
            IntOperationIr::Matmul(repr) => Box::new([&repr.out].into_iter()),
            IntOperationIr::IntoFloat(repr) => Box::new([&repr.out].into_iter()),
            IntOperationIr::BitwiseAnd(repr) => Box::new([&repr.out].into_iter()),
            IntOperationIr::BitwiseAndScalar(repr) => Box::new([&repr.out].into_iter()),
            IntOperationIr::BitwiseOr(repr) => Box::new([&repr.out].into_iter()),
            IntOperationIr::BitwiseOrScalar(repr) => Box::new([&repr.out].into_iter()),
            IntOperationIr::BitwiseXor(repr) => Box::new([&repr.out].into_iter()),
            IntOperationIr::BitwiseXorScalar(repr) => Box::new([&repr.out].into_iter()),
            IntOperationIr::BitwiseNot(repr) => Box::new([&repr.out].into_iter()),
            IntOperationIr::BitwiseLeftShift(repr) => Box::new([&repr.out].into_iter()),
            IntOperationIr::BitwiseLeftShiftScalar(repr) => Box::new([&repr.out].into_iter()),
            IntOperationIr::BitwiseRightShift(repr) => Box::new([&repr.out].into_iter()),
            IntOperationIr::BitwiseRightShiftScalar(repr) => Box::new([&repr.out].into_iter()),
        }
    }

    /// Marks the input tensors listed in `nodes` as read-only and returns the
    /// descriptors collected by the per-tensor `mark_read_only` calls.
    /// Marks exactly the tensors that `inputs()` reports.
    fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
        let mut output = Vec::new();

        match self {
            IntOperationIr::Matmul(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            IntOperationIr::IntoFloat(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            IntOperationIr::BitwiseAnd(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            IntOperationIr::BitwiseAndScalar(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
            }
            IntOperationIr::BitwiseOr(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            IntOperationIr::BitwiseOrScalar(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
            }
            IntOperationIr::BitwiseXor(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            IntOperationIr::BitwiseXorScalar(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
            }
            IntOperationIr::BitwiseNot(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            IntOperationIr::BitwiseLeftShift(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            IntOperationIr::BitwiseLeftShiftScalar(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
            }
            IntOperationIr::BitwiseRightShift(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            IntOperationIr::BitwiseRightShiftScalar(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
            }
        };

        output
    }
}
2316
2317impl BoolOperationIr {
2318 fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
2319 match self {
2320 BoolOperationIr::IntoFloat(repr) => Box::new([&repr.input].into_iter()),
2321 BoolOperationIr::IntoInt(repr) => Box::new([&repr.input].into_iter()),
2322 BoolOperationIr::Not(repr) => Box::new([&repr.input].into_iter()),
2323 BoolOperationIr::And(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
2324 BoolOperationIr::Or(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
2325 }
2326 }
2327 fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
2328 match self {
2329 BoolOperationIr::IntoFloat(repr) => Box::new([&repr.out].into_iter()),
2330 BoolOperationIr::IntoInt(repr) => Box::new([&repr.out].into_iter()),
2331 BoolOperationIr::Not(repr) => Box::new([&repr.out].into_iter()),
2332 BoolOperationIr::And(repr) => Box::new([&repr.out].into_iter()),
2333 BoolOperationIr::Or(repr) => Box::new([&repr.out].into_iter()),
2334 }
2335 }
2336 fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
2337 let mut output = Vec::new();
2338
2339 match self {
2340 BoolOperationIr::IntoFloat(repr) => {
2341 repr.input.mark_read_only(nodes, &mut output);
2342 }
2343 BoolOperationIr::IntoInt(repr) => {
2344 repr.input.mark_read_only(nodes, &mut output);
2345 }
2346 BoolOperationIr::Not(repr) => {
2347 repr.input.mark_read_only(nodes, &mut output);
2348 }
2349 BoolOperationIr::And(repr) => {
2350 repr.lhs.mark_read_only(nodes, &mut output);
2351 repr.rhs.mark_read_only(nodes, &mut output);
2352 }
2353 BoolOperationIr::Or(repr) => {
2354 repr.lhs.mark_read_only(nodes, &mut output);
2355 repr.rhs.mark_read_only(nodes, &mut output);
2356 }
2357 };
2358
2359 output
2360 }
2361}
2362
2363impl ModuleOperationIr {
2364 fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
2365 match self {
2366 ModuleOperationIr::Embedding(repr) => {
2367 Box::new([&repr.weights, &repr.indices].into_iter())
2368 }
2369 ModuleOperationIr::EmbeddingBackward(repr) => {
2370 Box::new([&repr.weights, &repr.out_grad, &repr.indices].into_iter())
2371 }
2372 ModuleOperationIr::Conv1d(repr) => {
2373 if let Some(bias) = &repr.bias {
2374 Box::new([&repr.x, &repr.weight, bias].into_iter())
2375 } else {
2376 Box::new([&repr.x, &repr.weight].into_iter())
2377 }
2378 }
2379 ModuleOperationIr::Conv2d(repr) => {
2380 if let Some(bias) = &repr.bias {
2381 Box::new([&repr.x, &repr.weight, bias].into_iter())
2382 } else {
2383 Box::new([&repr.x, &repr.weight].into_iter())
2384 }
2385 }
2386 ModuleOperationIr::Conv3d(repr) => {
2387 if let Some(bias) = &repr.bias {
2388 Box::new([&repr.x, &repr.weight, bias].into_iter())
2389 } else {
2390 Box::new([&repr.x, &repr.weight].into_iter())
2391 }
2392 }
2393 ModuleOperationIr::DeformableConv2d(repr) => match (&repr.mask, &repr.bias) {
2394 (Some(mask), Some(bias)) => {
2395 Box::new([&repr.x, &repr.offset, &repr.weight, mask, bias].into_iter())
2396 }
2397 (Some(mask), None) => {
2398 Box::new([&repr.x, &repr.offset, &repr.weight, mask].into_iter())
2399 }
2400 (None, Some(bias)) => {
2401 Box::new([&repr.x, &repr.offset, &repr.weight, bias].into_iter())
2402 }
2403 (None, None) => Box::new([&repr.x, &repr.offset, &repr.weight].into_iter()),
2404 },
2405 ModuleOperationIr::DeformableConv2dBackward(repr) => match (&repr.mask, &repr.bias) {
2406 (Some(mask), Some(bias)) => Box::new(
2407 [
2408 &repr.x,
2409 &repr.offset,
2410 &repr.weight,
2411 &repr.out_grad,
2412 mask,
2413 bias,
2414 ]
2415 .into_iter(),
2416 ),
2417 (Some(mask), None) => Box::new(
2418 [&repr.x, &repr.offset, &repr.weight, &repr.out_grad, mask].into_iter(),
2419 ),
2420 (None, Some(bias)) => Box::new(
2421 [&repr.x, &repr.offset, &repr.weight, &repr.out_grad, bias].into_iter(),
2422 ),
2423 (None, None) => {
2424 Box::new([&repr.x, &repr.offset, &repr.weight, &repr.out_grad].into_iter())
2425 }
2426 },
2427 ModuleOperationIr::ConvTranspose1d(repr) => {
2428 if let Some(bias) = &repr.bias {
2429 Box::new([&repr.x, &repr.weight, bias].into_iter())
2430 } else {
2431 Box::new([&repr.x, &repr.weight].into_iter())
2432 }
2433 }
2434 ModuleOperationIr::ConvTranspose2d(repr) => {
2435 if let Some(bias) = &repr.bias {
2436 Box::new([&repr.x, &repr.weight, bias].into_iter())
2437 } else {
2438 Box::new([&repr.x, &repr.weight].into_iter())
2439 }
2440 }
2441 ModuleOperationIr::ConvTranspose3d(repr) => {
2442 if let Some(bias) = &repr.bias {
2443 Box::new([&repr.x, &repr.weight, bias].into_iter())
2444 } else {
2445 Box::new([&repr.x, &repr.weight].into_iter())
2446 }
2447 }
2448 ModuleOperationIr::AvgPool1d(repr) => Box::new([&repr.x].into_iter()),
2449 ModuleOperationIr::AvgPool2d(repr) => Box::new([&repr.x].into_iter()),
2450 ModuleOperationIr::AvgPool1dBackward(repr) => {
2451 Box::new([&repr.x, &repr.grad].into_iter())
2452 }
2453 ModuleOperationIr::AvgPool2dBackward(repr) => {
2454 Box::new([&repr.x, &repr.grad].into_iter())
2455 }
2456 ModuleOperationIr::AdaptiveAvgPool1d(repr) => Box::new([&repr.x].into_iter()),
2457 ModuleOperationIr::AdaptiveAvgPool2d(repr) => Box::new([&repr.x].into_iter()),
2458 ModuleOperationIr::AdaptiveAvgPool1dBackward(repr) => {
2459 Box::new([&repr.x, &repr.grad].into_iter())
2460 }
2461 ModuleOperationIr::AdaptiveAvgPool2dBackward(repr) => {
2462 Box::new([&repr.x, &repr.grad].into_iter())
2463 }
2464 ModuleOperationIr::MaxPool1d(repr) => Box::new([&repr.x].into_iter()),
2465 ModuleOperationIr::MaxPool1dWithIndices(repr) => Box::new([&repr.x].into_iter()),
2466 ModuleOperationIr::MaxPool1dWithIndicesBackward(repr) => {
2467 Box::new([&repr.x, &repr.indices, &repr.grad].into_iter())
2468 }
2469 ModuleOperationIr::MaxPool2d(repr) => Box::new([&repr.x].into_iter()),
2470 ModuleOperationIr::MaxPool2dWithIndices(repr) => Box::new([&repr.x].into_iter()),
2471 ModuleOperationIr::MaxPool2dWithIndicesBackward(repr) => {
2472 Box::new([&repr.x, &repr.indices, &repr.grad].into_iter())
2473 }
2474 ModuleOperationIr::Interpolate(repr) => Box::new([&repr.x].into_iter()),
2475 ModuleOperationIr::InterpolateBackward(repr) => {
2476 Box::new([&repr.x, &repr.grad].into_iter())
2477 }
2478 }
2479 }
2480 fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
2481 match self {
2482 ModuleOperationIr::Embedding(repr) => Box::new([&repr.out].into_iter()),
2483 ModuleOperationIr::EmbeddingBackward(repr) => Box::new([&repr.out].into_iter()),
2484 ModuleOperationIr::Conv1d(repr) => Box::new([&repr.out].into_iter()),
2485 ModuleOperationIr::Conv2d(repr) => Box::new([&repr.out].into_iter()),
2486 ModuleOperationIr::Conv3d(repr) => Box::new([&repr.out].into_iter()),
2487 ModuleOperationIr::DeformableConv2d(repr) => Box::new([&repr.out].into_iter()),
2488 ModuleOperationIr::DeformableConv2dBackward(repr) => {
2489 match (&repr.mask_grad, &repr.bias_grad) {
2490 (Some(mask_grad), Some(bias_grad)) => Box::new(
2491 [
2492 &repr.input_grad,
2493 &repr.offset_grad,
2494 &repr.weight_grad,
2495 mask_grad,
2496 bias_grad,
2497 ]
2498 .into_iter(),
2499 ),
2500 (Some(mask_grad), None) => Box::new(
2501 [
2502 &repr.input_grad,
2503 &repr.offset_grad,
2504 &repr.weight_grad,
2505 mask_grad,
2506 ]
2507 .into_iter(),
2508 ),
2509 (None, Some(bias_grad)) => Box::new(
2510 [
2511 &repr.input_grad,
2512 &repr.offset_grad,
2513 &repr.weight_grad,
2514 bias_grad,
2515 ]
2516 .into_iter(),
2517 ),
2518 (None, None) => Box::new(
2519 [&repr.input_grad, &repr.offset_grad, &repr.weight_grad].into_iter(),
2520 ),
2521 }
2522 }
2523 ModuleOperationIr::ConvTranspose1d(repr) => Box::new([&repr.out].into_iter()),
2524 ModuleOperationIr::ConvTranspose2d(repr) => Box::new([&repr.out].into_iter()),
2525 ModuleOperationIr::ConvTranspose3d(repr) => Box::new([&repr.out].into_iter()),
2526 ModuleOperationIr::AvgPool1d(repr) => Box::new([&repr.out].into_iter()),
2527 ModuleOperationIr::AvgPool2d(repr) => Box::new([&repr.out].into_iter()),
2528 ModuleOperationIr::AvgPool1dBackward(repr) => Box::new([&repr.out].into_iter()),
2529 ModuleOperationIr::AvgPool2dBackward(repr) => Box::new([&repr.out].into_iter()),
2530 ModuleOperationIr::AdaptiveAvgPool1d(repr) => Box::new([&repr.out].into_iter()),
2531 ModuleOperationIr::AdaptiveAvgPool2d(repr) => Box::new([&repr.out].into_iter()),
2532 ModuleOperationIr::AdaptiveAvgPool1dBackward(repr) => Box::new([&repr.out].into_iter()),
2533 ModuleOperationIr::AdaptiveAvgPool2dBackward(repr) => Box::new([&repr.out].into_iter()),
2534 ModuleOperationIr::MaxPool1d(repr) => Box::new([&repr.out].into_iter()),
2535 ModuleOperationIr::MaxPool1dWithIndices(repr) => {
2536 Box::new([&repr.out, &repr.out_indices].into_iter())
2537 }
2538 ModuleOperationIr::MaxPool1dWithIndicesBackward(repr) => {
2539 Box::new([&repr.out].into_iter())
2540 }
2541 ModuleOperationIr::MaxPool2d(repr) => Box::new([&repr.out].into_iter()),
2542 ModuleOperationIr::MaxPool2dWithIndices(repr) => {
2543 Box::new([&repr.out, &repr.out_indices].into_iter())
2544 }
2545 ModuleOperationIr::MaxPool2dWithIndicesBackward(repr) => {
2546 Box::new([&repr.out].into_iter())
2547 }
2548 ModuleOperationIr::Interpolate(repr) => Box::new([&repr.out].into_iter()),
2549 ModuleOperationIr::InterpolateBackward(repr) => Box::new([&repr.out].into_iter()),
2550 }
2551 }
2552
2553 fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
2554 let mut output = Vec::new();
2555
2556 match self {
2557 ModuleOperationIr::Embedding(repr) => {
2558 repr.weights.mark_read_only(nodes, &mut output);
2559 repr.indices.mark_read_only(nodes, &mut output);
2560 }
2561 ModuleOperationIr::EmbeddingBackward(repr) => {
2562 repr.weights.mark_read_only(nodes, &mut output);
2563 repr.out_grad.mark_read_only(nodes, &mut output);
2564 repr.indices.mark_read_only(nodes, &mut output);
2565 }
2566 ModuleOperationIr::Conv1d(repr) => {
2567 repr.x.mark_read_only(nodes, &mut output);
2568 repr.weight.mark_read_only(nodes, &mut output);
2569
2570 if let Some(bias) = &mut repr.bias {
2571 bias.mark_read_only(nodes, &mut output);
2572 }
2573 }
2574 ModuleOperationIr::Conv2d(repr) => {
2575 repr.x.mark_read_only(nodes, &mut output);
2576 repr.weight.mark_read_only(nodes, &mut output);
2577
2578 if let Some(bias) = &mut repr.bias {
2579 bias.mark_read_only(nodes, &mut output);
2580 }
2581 }
2582 ModuleOperationIr::Conv3d(repr) => {
2583 repr.x.mark_read_only(nodes, &mut output);
2584 repr.weight.mark_read_only(nodes, &mut output);
2585
2586 if let Some(bias) = &mut repr.bias {
2587 bias.mark_read_only(nodes, &mut output);
2588 }
2589 }
2590 ModuleOperationIr::DeformableConv2d(repr) => {
2591 repr.x.mark_read_only(nodes, &mut output);
2592 repr.weight.mark_read_only(nodes, &mut output);
2593 repr.offset.mark_read_only(nodes, &mut output);
2594
2595 match (&mut repr.mask, &mut repr.bias) {
2596 (Some(mask), Some(bias)) => {
2597 mask.mark_read_only(nodes, &mut output);
2598 bias.mark_read_only(nodes, &mut output);
2599 }
2600 (Some(mask), None) => {
2601 mask.mark_read_only(nodes, &mut output);
2602 }
2603 (None, Some(bias)) => {
2604 bias.mark_read_only(nodes, &mut output);
2605 }
2606 (None, None) => {}
2607 };
2608 }
2609 ModuleOperationIr::DeformableConv2dBackward(repr) => {
2610 repr.x.mark_read_only(nodes, &mut output);
2611 repr.weight.mark_read_only(nodes, &mut output);
2612 repr.offset.mark_read_only(nodes, &mut output);
2613 repr.out_grad.mark_read_only(nodes, &mut output);
2614
2615 if let Some(mask) = repr.mask.as_mut() {
2616 mask.mark_read_only(nodes, &mut output);
2617 }
2618 if let Some(bias) = repr.bias.as_mut() {
2619 bias.mark_read_only(nodes, &mut output);
2620 }
2621 }
2622 ModuleOperationIr::ConvTranspose1d(repr) => {
2623 repr.x.mark_read_only(nodes, &mut output);
2624 repr.weight.mark_read_only(nodes, &mut output);
2625
2626 if let Some(bias) = &mut repr.bias {
2627 bias.mark_read_only(nodes, &mut output);
2628 }
2629 }
2630 ModuleOperationIr::ConvTranspose2d(repr) => {
2631 repr.x.mark_read_only(nodes, &mut output);
2632 repr.weight.mark_read_only(nodes, &mut output);
2633
2634 if let Some(bias) = &mut repr.bias {
2635 bias.mark_read_only(nodes, &mut output);
2636 }
2637 }
2638 ModuleOperationIr::ConvTranspose3d(repr) => {
2639 repr.x.mark_read_only(nodes, &mut output);
2640 repr.weight.mark_read_only(nodes, &mut output);
2641
2642 if let Some(bias) = &mut repr.bias {
2643 bias.mark_read_only(nodes, &mut output);
2644 }
2645 }
2646 ModuleOperationIr::AvgPool1d(repr) => {
2647 repr.x.mark_read_only(nodes, &mut output);
2648 }
2649 ModuleOperationIr::AvgPool2d(repr) => {
2650 repr.x.mark_read_only(nodes, &mut output);
2651 }
2652 ModuleOperationIr::AvgPool1dBackward(repr) => {
2653 repr.x.mark_read_only(nodes, &mut output);
2654 repr.grad.mark_read_only(nodes, &mut output);
2655 }
2656 ModuleOperationIr::AvgPool2dBackward(repr) => {
2657 repr.x.mark_read_only(nodes, &mut output);
2658 repr.grad.mark_read_only(nodes, &mut output);
2659 }
2660 ModuleOperationIr::AdaptiveAvgPool1d(repr) => {
2661 repr.x.mark_read_only(nodes, &mut output);
2662 }
2663 ModuleOperationIr::AdaptiveAvgPool2d(repr) => {
2664 repr.x.mark_read_only(nodes, &mut output);
2665 }
2666 ModuleOperationIr::AdaptiveAvgPool1dBackward(repr) => {
2667 repr.x.mark_read_only(nodes, &mut output);
2668 repr.grad.mark_read_only(nodes, &mut output);
2669 }
2670 ModuleOperationIr::AdaptiveAvgPool2dBackward(repr) => {
2671 repr.x.mark_read_only(nodes, &mut output);
2672 repr.grad.mark_read_only(nodes, &mut output);
2673 }
2674 ModuleOperationIr::MaxPool1d(repr) => {
2675 repr.x.mark_read_only(nodes, &mut output);
2676 }
2677 ModuleOperationIr::MaxPool1dWithIndices(repr) => {
2678 repr.x.mark_read_only(nodes, &mut output);
2679 }
2680 ModuleOperationIr::MaxPool1dWithIndicesBackward(repr) => {
2681 repr.x.mark_read_only(nodes, &mut output);
2682 repr.grad.mark_read_only(nodes, &mut output);
2683 }
2684 ModuleOperationIr::MaxPool2d(repr) => {
2685 repr.x.mark_read_only(nodes, &mut output);
2686 }
2687 ModuleOperationIr::MaxPool2dWithIndices(repr) => {
2688 repr.x.mark_read_only(nodes, &mut output);
2689 }
2690 ModuleOperationIr::MaxPool2dWithIndicesBackward(repr) => {
2691 repr.x.mark_read_only(nodes, &mut output);
2692 repr.grad.mark_read_only(nodes, &mut output);
2693 }
2694 ModuleOperationIr::Interpolate(repr) => {
2695 repr.x.mark_read_only(nodes, &mut output);
2696 }
2697 ModuleOperationIr::InterpolateBackward(repr) => {
2698 repr.x.mark_read_only(nodes, &mut output);
2699 repr.grad.mark_read_only(nodes, &mut output);
2700 }
2701 };
2702
2703 output
2704 }
2705}
2706
2707impl InitOperationIr {
2708 fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
2709 Box::new([].into_iter())
2710 }
2711 fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
2712 Box::new([&self.out].into_iter())
2713 }
2714}
2715
2716impl TensorIr {
2717 fn mark_read_only(&mut self, nodes: &[TensorId], output: &mut Vec<TensorIr>) {
2718 if self.status == TensorStatus::ReadWrite && nodes.contains(&self.id) {
2719 output.push(self.clone());
2720 self.status = TensorStatus::ReadOnly;
2721 }
2722 }
2723}
2724
2725impl core::hash::Hash for RandomOpIr {
2726 fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
2727 self.out.hash(state);
2728
2729 match self.distribution {
2730 Distribution::Default => 1u8.hash(state),
2731 Distribution::Bernoulli(_) => 2u8.hash(state),
2732 Distribution::Uniform(_, _) => 3u8.hash(state),
2733 Distribution::Normal(_, _) => 4u8.hash(state),
2734 }
2735 }
2736}
2737
/// Access the output(s) produced by an operation.
pub trait OperationOutput<O> {
    /// Consume `self` and return its single output.
    fn output(self) -> O;

    /// Consume `self` and return exactly `N` outputs.
    fn outputs<const N: usize>(self) -> [O; N];
}
2746
2747impl<O: core::fmt::Debug> OperationOutput<O> for Vec<O> {
2748 fn output(self) -> O {
2749 let [tensor] = self.outputs();
2750 tensor
2751 }
2752
2753 fn outputs<const N: usize>(self) -> [O; N] {
2754 self.try_into().unwrap()
2755 }
2756}