1use burn_backend::ops::AttentionModuleOptions;
2use burn_backend::tensor::IndexingUpdateOp;
3use core::hash::Hash;
4use serde::{Deserialize, Serialize};
5
6use alloc::borrow::ToOwned;
7use alloc::boxed::Box;
8use alloc::{string::String, vec::Vec};
9
10use burn_backend::{
11 DType, Distribution, Slice,
12 ops::{
13 ConvOptions, ConvTransposeOptions, DeformConvOptions, GridSampleOptions,
14 GridSamplePaddingMode, InterpolateMode, InterpolateOptions,
15 },
16 quantization::QuantScheme,
17};
18
19use crate::{ScalarIr, TensorId, TensorIr, TensorStatus};
20
/// Description of a user-defined (custom) operation in the IR,
/// identified by name with explicit input and output tensors.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct CustomOpIr {
    /// Unique identifier for this custom operation.
    pub id: String,
    /// Descriptions of the input tensors.
    pub inputs: Vec<TensorIr>,
    /// Descriptions of the output tensors.
    pub outputs: Vec<TensorIr>,
}
31
32impl CustomOpIr {
33 pub fn new(id: &'static str, inputs: &[TensorIr], outputs: &[TensorIr]) -> Self {
35 Self {
36 id: id.to_owned(),
37 inputs: inputs.to_vec(),
38 outputs: outputs.to_vec(),
39 }
40 }
41
42 pub fn as_fixed<const N_IN: usize, const N_OUT: usize>(
44 &self,
45 ) -> (&[TensorIr; N_IN], &[TensorIr; N_OUT]) {
46 (
47 self.inputs.as_slice().try_into().expect(
48 "Wrong number of inputs expected (expected {D}, is {}), check your implementation",
49 ),
50 self.outputs.as_slice().try_into().expect(
51 "Wrong number of outputs expected (expected {D}, is {}), check your implementation",
52 ),
53 )
54 }
55
56 fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
57 Box::new(self.inputs.iter())
58 }
59
60 fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
61 Box::new(self.outputs.iter())
62 }
63}
64
/// Top-level intermediate representation of a tensor operation,
/// dispatched by tensor kind (float / int / bool) and operation family.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(clippy::large_enum_variant)]
pub enum OperationIr {
    /// Kind-agnostic operation executed on float tensors.
    BaseFloat(BaseOperationIr),
    /// Kind-agnostic operation executed on int tensors.
    BaseInt(BaseOperationIr),
    /// Kind-agnostic operation executed on bool tensors.
    BaseBool(BaseOperationIr),
    /// Numeric operation on float tensors, with the element dtype.
    NumericFloat(DType, NumericOperationIr),
    /// Numeric operation on int tensors, with the element dtype.
    NumericInt(DType, NumericOperationIr),
    /// Bool-only operation.
    Bool(BoolOperationIr),
    /// Int-only operation.
    Int(IntOperationIr),
    /// Float-only operation, with the element dtype.
    Float(DType, FloatOperationIr),
    /// Neural-network module operation (conv, pooling, attention, ...).
    Module(ModuleOperationIr),
    /// Tensor initialization operation.
    Init(InitOperationIr),
    /// User-defined operation.
    Custom(CustomOpIr),
    /// Drops (releases) a tensor.
    Drop(TensorIr),
    /// Distributed/collective operation (feature-gated).
    #[cfg(feature = "distributed")]
    Distributed(DistributedOperationIr),
}
97
/// Operations that only apply to float tensors.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum FloatOperationIr {
    // Pointwise math functions.
    Exp(UnaryOpIr),
    Log(UnaryOpIr),
    Log1p(UnaryOpIr),
    Erf(UnaryOpIr),
    PowfScalar(ScalarOpIr),
    Sqrt(UnaryOpIr),
    // Trigonometric / hyperbolic functions and their inverses.
    Cos(UnaryOpIr),
    Cosh(UnaryOpIr),
    Sin(UnaryOpIr),
    Sinh(UnaryOpIr),
    Tan(UnaryOpIr),
    Tanh(UnaryOpIr),
    ArcCos(UnaryOpIr),
    ArcCosh(UnaryOpIr),
    ArcSin(UnaryOpIr),
    ArcSinh(UnaryOpIr),
    ArcTan(UnaryOpIr),
    ArcTanh(UnaryOpIr),
    ArcTan2(BinaryOpIr),
    // Rounding.
    Round(UnaryOpIr),
    Floor(UnaryOpIr),
    Ceil(UnaryOpIr),
    Trunc(UnaryOpIr),
    // Conversions, linear algebra and misc.
    IntoInt(CastOpIr),
    Matmul(MatmulOpIr),
    Cross(CrossOpIr),
    Random(RandomOpIr),
    Recip(UnaryOpIr),
    IsNan(UnaryOpIr),
    IsInf(UnaryOpIr),
    // Quantization.
    Quantize(QuantizeOpIr),
    Dequantize(DequantizeOpIr),
    GridSample2d(GridSample2dOpIr),
    Powf(BinaryOpIr),
}
170
/// Neural-network module operations (forward passes and their backward steps).
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum ModuleOperationIr {
    // Embedding.
    Embedding(EmbeddingOpIr),
    EmbeddingBackward(EmbeddingBackwardOpIr),
    // Linear layer and its gradients w.r.t. input, weight and bias.
    Linear(LinearOpIr),
    LinearXBackward(LinearXBackwardOpIr),
    LinearWeightBackward(LinearWeightBackwardOpIr),
    LinearBiasBackward(LinearBiasBackwardOpIr),
    // Convolutions (1d/2d/3d) and their gradients.
    Conv1d(Conv1dOpIr),
    Conv1dXBackward(Conv1dXBackwardOpIr),
    Conv1dWeightBackward(Conv1dWeightBackwardOpIr),
    Conv1dBiasBackward(Conv1dBiasBackwardOpIr),
    Conv2d(Conv2dOpIr),
    Conv2dXBackward(Conv2dXBackwardOpIr),
    Conv2dWeightBackward(Conv2dWeightBackwardOpIr),
    Conv2dBiasBackward(Conv2dBiasBackwardOpIr),
    Conv3d(Conv3dOpIr),
    Conv3dXBackward(Conv3dXBackwardOpIr),
    Conv3dWeightBackward(Conv3dWeightBackwardOpIr),
    Conv3dBiasBackward(Conv3dBiasBackwardOpIr),
    // Deformable convolution (boxed: large variants, see clippy::large_enum_variant).
    DeformableConv2d(Box<DeformConv2dOpIr>),
    DeformableConv2dBackward(Box<DeformConv2dBackwardOpIr>),
    // Transposed convolutions.
    ConvTranspose1d(ConvTranspose1dOpIr),
    ConvTranspose2d(ConvTranspose2dOpIr),
    ConvTranspose3d(ConvTranspose3dOpIr),
    // Average pooling.
    AvgPool1d(AvgPool1dOpIr),
    AvgPool2d(AvgPool2dOpIr),
    AvgPool1dBackward(AvgPool1dBackwardOpIr),
    AvgPool2dBackward(AvgPool2dBackwardOpIr),
    // Adaptive average pooling.
    AdaptiveAvgPool1d(AdaptiveAvgPool1dOpIr),
    AdaptiveAvgPool2d(AdaptiveAvgPool2dOpIr),
    AdaptiveAvgPool1dBackward(AdaptiveAvgPool1dBackwardOpIr),
    AdaptiveAvgPool2dBackward(AdaptiveAvgPool2dBackwardOpIr),
    // Max pooling (optionally returning argmax indices).
    MaxPool1d(MaxPool1dOpIr),
    MaxPool1dWithIndices(MaxPool1dWithIndicesOpIr),
    MaxPool1dWithIndicesBackward(MaxPool1dWithIndicesBackwardOpIr),
    MaxPool2d(MaxPool2dOpIr),
    MaxPool2dWithIndices(MaxPool2dWithIndicesOpIr),
    MaxPool2dWithIndicesBackward(MaxPool2dWithIndicesBackwardOpIr),
    // Interpolation (resize).
    Interpolate(InterpolateOpIr),
    InterpolateBackward(InterpolateBackwardOpIr),
    // Real-valued FFT and inverse.
    Rfft(RfftOpIr),
    IRfft(IRfftOpIr),
    // Attention.
    Attention(AttentionOpIr),
    // CTC loss.
    CtcLoss(CtcLossOpIr),
    CtcLossBackward(CtcLossBackwardOpIr),
}
276
/// Operations applicable to all tensor kinds (float, int and bool).
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum BaseOperationIr {
    // Shape / layout transformations.
    Reshape(ShapeOpIr),

    SwapDims(SwapDimsOpIr),

    Permute(PermuteOpIr),

    Flip(FlipOpIr),

    Expand(ShapeOpIr),

    Unfold(UnfoldOpIr),

    // Indexing, slicing and masking.
    Slice(SliceOpIr),
    SliceAssign(SliceAssignOpIr),
    Select(SelectOpIr),
    SelectAssign(SelectAssignOpIr),
    MaskWhere(MaskWhereOpIr),
    MaskFill(MaskFillOpIr),
    Gather(GatherOpIr),
    Scatter(ScatterOpIr),
    ScatterNd(ScatterNdOpIr),
    GatherNd(GatherNdOpIr),
    // Comparison, repetition, concatenation and casting.
    Equal(BinaryOpIr),
    EqualElem(ScalarOpIr),
    RepeatDim(RepeatDimOpIr),
    Cat(CatOpIr),
    Cast(CastOpIr),
    // Creation.
    Empty(CreationOpIr),
    Ones(CreationOpIr),
    Zeros(CreationOpIr),
}
415
/// Operations applicable to numeric tensors (both float and int).
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum NumericOperationIr {
    // Elementwise arithmetic (tensor-tensor and tensor-scalar).
    Add(BinaryOpIr),
    AddScalar(ScalarOpIr),
    Sub(BinaryOpIr),
    SubScalar(ScalarOpIr),
    Div(BinaryOpIr),
    DivScalar(ScalarOpIr),
    Rem(BinaryOpIr),
    RemScalar(ScalarOpIr),
    Mul(BinaryOpIr),
    MulScalar(ScalarOpIr),
    Abs(UnaryOpIr),
    Full(FullOpIr),
    // Reductions (whole-tensor and per-dimension).
    MeanDim(ReduceDimOpIr),
    Mean(ReduceOpIr),
    Sum(ReduceOpIr),
    SumDim(ReduceDimOpIr),
    Prod(ReduceOpIr),
    ProdDim(ReduceDimOpIr),
    // Comparisons.
    Greater(BinaryOpIr),
    GreaterElem(ScalarOpIr),
    GreaterEqual(BinaryOpIr),
    GreaterEqualElem(ScalarOpIr),
    Lower(BinaryOpIr),
    LowerElem(ScalarOpIr),
    LowerEqual(BinaryOpIr),
    LowerEqualElem(ScalarOpIr),
    // Arg/top-k reductions and min/max (with or without indices).
    ArgMax(ReduceDimOpIr),
    ArgTopK(ReduceDimOpIr),
    TopK(ReduceDimOpIr),
    ArgMin(ReduceDimOpIr),
    Max(ReduceOpIr),
    MaxDimWithIndices(ReduceDimWithIndicesOpIr),
    MinDimWithIndices(ReduceDimWithIndicesOpIr),
    Min(ReduceOpIr),
    MaxDim(ReduceDimOpIr),
    MinDim(ReduceDimOpIr),
    MaxAbs(ReduceOpIr),
    MaxAbsDim(ReduceDimOpIr),
    // Misc numeric ops.
    Clamp(ClampOpIr),
    IntRandom(RandomOpIr),
    Powi(BinaryOpIr),
    PowiScalar(ScalarOpIr),
    // Cumulative scans along a dimension.
    CumSum(DimOpIr),
    CumProd(DimOpIr),
    CumMin(DimOpIr),
    CumMax(DimOpIr),
}
649
/// Operations that only apply to int tensors.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum IntOperationIr {
    IntoFloat(CastOpIr),
    // Bitwise logic (tensor-tensor and tensor-scalar).
    BitwiseAnd(BinaryOpIr),
    BitwiseAndScalar(ScalarOpIr),
    BitwiseOr(BinaryOpIr),
    BitwiseOrScalar(ScalarOpIr),
    BitwiseXor(BinaryOpIr),
    BitwiseXorScalar(ScalarOpIr),
    BitwiseNot(UnaryOpIr),
    // Bit shifts.
    BitwiseLeftShift(BinaryOpIr),
    BitwiseLeftShiftScalar(ScalarOpIr),
    BitwiseRightShift(BinaryOpIr),
    BitwiseRightShiftScalar(ScalarOpIr),
    Matmul(MatmulOpIr),
}
702
/// Operations that only apply to bool tensors.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum BoolOperationIr {
    /// Cast to a float tensor.
    IntoFloat(CastOpIr),
    /// Cast to an int tensor.
    IntoInt(CastOpIr),
    /// Logical negation.
    Not(UnaryOpIr),
    /// Logical AND.
    And(BinaryOpIr),
    /// Logical OR.
    Or(BinaryOpIr),
}
717
/// Distributed (collective) operations, only available with the
/// `distributed` feature.
#[cfg(feature = "distributed")]
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum DistributedOperationIr {
    /// All-reduce across participating devices.
    AllReduce(AllReduceOpIr),
}
726
/// Swaps two dimensions of a tensor.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct SwapDimsOpIr {
    /// Input tensor.
    pub input: TensorIr,
    /// Output tensor.
    pub out: TensorIr,
    /// First dimension to swap.
    pub dim1: usize,
    /// Second dimension to swap.
    pub dim2: usize,
}

/// Permutes the dimensions of a tensor according to `axes`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct PermuteOpIr {
    pub input: TensorIr,
    pub out: TensorIr,
    /// New ordering of the input axes.
    pub axes: Vec<usize>,
}

/// Shape-changing operation (reshape / expand); the target shape is
/// carried by the output tensor description.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct ShapeOpIr {
    pub input: TensorIr,
    pub out: TensorIr,
}

/// Extracts sliding windows of length `size` with stride `step`
/// along dimension `dim`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct UnfoldOpIr {
    pub input: TensorIr,
    pub out: TensorIr,

    /// Dimension along which windows are taken.
    pub dim: usize,
    /// Window length.
    pub size: usize,
    /// Stride between consecutive windows.
    pub step: usize,
}

/// Reverses the tensor along the given axes.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct FlipOpIr {
    pub input: TensorIr,
    pub out: TensorIr,
    pub axes: Vec<usize>,
}

/// Fills the output with values drawn from `distribution`.
/// Note: no `Hash` derive, since `Distribution` carries float parameters.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct RandomOpIr {
    pub out: TensorIr,
    pub distribution: Distribution,
}

/// Creation operation (empty / ones / zeros); fully described by its output.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct CreationOpIr {
    pub out: TensorIr,
}

/// Fills the output with a single scalar value.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct FullOpIr {
    pub out: TensorIr,
    pub value: ScalarIr,
}

/// Initialization operation; fully described by its output tensor.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct InitOperationIr {
    pub out: TensorIr,
}
819
/// Elementwise binary operation on two tensors.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct BinaryOpIr {
    pub lhs: TensorIr,
    pub rhs: TensorIr,
    pub out: TensorIr,
}

/// Matrix multiplication.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MatmulOpIr {
    pub lhs: TensorIr,
    pub rhs: TensorIr,
    pub out: TensorIr,
}

/// Cross product along dimension `dim`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct CrossOpIr {
    pub lhs: TensorIr,
    pub rhs: TensorIr,
    pub out: TensorIr,
    pub dim: usize,
}

/// Elementwise unary operation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct UnaryOpIr {
    pub input: TensorIr,
    pub out: TensorIr,
}

/// Binary operation between a tensor (lhs) and a scalar (rhs).
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ScalarOpIr {
    pub lhs: TensorIr,
    // rhs is a scalar value, not a tensor.
    pub rhs: ScalarIr,
    pub out: TensorIr,
}

/// Reduction over the whole tensor.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Hash)]
#[allow(missing_docs)]
pub struct ReduceOpIr {
    pub input: TensorIr,
    pub out: TensorIr,
}

/// Reduction along a single axis.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Hash)]
#[allow(missing_docs)]
pub struct ReduceDimOpIr {
    pub input: TensorIr,
    pub out: TensorIr,
    pub axis: usize,
    // NOTE(review): presumably the number of elements accumulated along
    // `axis` (or the kept output length for top-k style ops) — confirm
    // against the handlers that consume this field.
    pub accumulator_len: usize,
}

/// Element-type cast.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct CastOpIr {
    pub input: TensorIr,
    pub out: TensorIr,
}

/// Operation parameterized by a single axis (e.g. cumulative scans).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Hash)]
#[allow(missing_docs)]
pub struct DimOpIr {
    pub input: TensorIr,
    pub out: TensorIr,
    pub axis: usize,
}
894
/// Gathers values along `dim` using an indices tensor.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct GatherOpIr {
    pub tensor: TensorIr,
    pub dim: usize,
    pub indices: TensorIr,
    pub out: TensorIr,
}

/// Scatters `value` into `tensor` along `dim` at `indices`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ScatterOpIr {
    pub tensor: TensorIr,
    pub dim: usize,
    pub indices: TensorIr,
    pub value: TensorIr,
    // How scattered values combine with existing elements (e.g. assign/add).
    pub update: IndexingUpdateOp,
    pub out: TensorIr,
}

/// N-dimensional scatter: writes `values` into `data` at multi-dim `indices`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ScatterNdOpIr {
    pub data: TensorIr,
    pub indices: TensorIr,
    pub values: TensorIr,
    // How written values combine with existing elements.
    pub reduction: IndexingUpdateOp,
    pub out: TensorIr,
}

/// N-dimensional gather from `data` at multi-dim `indices`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct GatherNdOpIr {
    pub data: TensorIr,
    pub indices: TensorIr,
    pub out: TensorIr,
}

/// Selects whole slices along `dim` at the given indices.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct SelectOpIr {
    pub tensor: TensorIr,
    pub dim: usize,
    pub indices: TensorIr,
    pub out: TensorIr,
}

/// Assigns `value` to the selected slices along `dim`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct SelectAssignOpIr {
    pub tensor: TensorIr,
    pub dim: usize,
    pub indices: TensorIr,
    pub value: TensorIr,
    // How assigned values combine with existing elements.
    pub update: IndexingUpdateOp,
    pub out: TensorIr,
}

/// Slices a tensor with one range per dimension.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct SliceOpIr {
    pub tensor: TensorIr,
    pub ranges: Vec<Slice>,
    pub out: TensorIr,
}
960
961#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
962#[allow(missing_docs)]
963pub struct SliceAssignOpIr {
964 pub tensor: TensorIr,
965 pub ranges: Vec<burn_backend::Slice>,
966 pub value: TensorIr,
967 pub out: TensorIr,
968}
969
/// Replaces elements of `tensor` with `value` where `mask` is true.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaskWhereOpIr {
    pub tensor: TensorIr,
    pub mask: TensorIr,
    pub value: TensorIr,
    pub out: TensorIr,
}

/// Fills masked elements of `tensor` with a scalar value.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaskFillOpIr {
    pub tensor: TensorIr,
    pub mask: TensorIr,
    pub value: ScalarIr,
    pub out: TensorIr,
}

/// Clamps every element into the `[min, max]` range.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ClampOpIr {
    pub tensor: TensorIr,
    pub min: ScalarIr,
    pub max: ScalarIr,
    pub out: TensorIr,
}

/// Repeats the tensor `times` times along `dim`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct RepeatDimOpIr {
    pub tensor: TensorIr,
    pub dim: usize,
    pub times: usize,
    pub out: TensorIr,
}

/// Concatenates `tensors` along `dim`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct CatOpIr {
    pub tensors: Vec<TensorIr>,
    pub dim: usize,
    pub out: TensorIr,
}
1013
/// All-reduce collective over a tensor (distributed feature only).
#[cfg(feature = "distributed")]
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AllReduceOpIr {
    pub tensor: TensorIr,
    pub out: TensorIr,
}

/// Reduction along `dim` that also produces the winning indices
/// (e.g. max/min with indices).
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ReduceDimWithIndicesOpIr {
    pub tensor: TensorIr,
    pub dim: usize,
    pub out: TensorIr,
    pub out_indices: TensorIr,
}
1030
/// Embedding lookup: rows of `weights` selected by `indices`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct EmbeddingOpIr {
    pub weights: TensorIr,
    pub indices: TensorIr,
    pub out: TensorIr,
}

/// Gradient of the embedding lookup w.r.t. the weights.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct EmbeddingBackwardOpIr {
    pub weights: TensorIr,
    pub out_grad: TensorIr,
    pub indices: TensorIr,
    pub out: TensorIr,
}

/// Linear (dense) layer forward pass with optional bias.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct LinearOpIr {
    pub x: TensorIr,
    pub weight: TensorIr,
    pub bias: Option<TensorIr>,
    pub out: TensorIr,
}

/// Linear layer gradient w.r.t. the input `x`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct LinearXBackwardOpIr {
    pub weight: TensorIr,
    pub output_grad: TensorIr,
    pub out: TensorIr,
}

/// Linear layer gradient w.r.t. the weight.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct LinearWeightBackwardOpIr {
    pub x: TensorIr,
    pub output_grad: TensorIr,
    pub out: TensorIr,
}

/// Linear layer gradient w.r.t. the bias.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct LinearBiasBackwardOpIr {
    pub output_grad: TensorIr,
    pub out: TensorIr,
}
1079
/// 1D convolution forward pass.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv1dOpIr {
    pub x: TensorIr,
    pub weight: TensorIr,
    pub bias: Option<TensorIr>,
    pub options: Conv1dOptionsIr,
    pub out: TensorIr,
}

/// 1D convolution gradient w.r.t. the input `x`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv1dXBackwardOpIr {
    pub x: TensorIr,
    pub weight: TensorIr,
    pub output_grad: TensorIr,
    pub options: Conv1dOptionsIr,
    pub out: TensorIr,
}

/// 1D convolution gradient w.r.t. the weight.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv1dWeightBackwardOpIr {
    pub x: TensorIr,
    pub weight: TensorIr,
    pub output_grad: TensorIr,
    pub options: Conv1dOptionsIr,
    pub out: TensorIr,
}

/// 1D convolution gradient w.r.t. the bias.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv1dBiasBackwardOpIr {
    pub x: TensorIr,
    pub bias: TensorIr,
    pub output_grad: TensorIr,
    pub out: TensorIr,
}

/// 2D convolution forward pass.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv2dOpIr {
    pub x: TensorIr,
    pub weight: TensorIr,
    pub bias: Option<TensorIr>,
    pub options: Conv2dOptionsIr,
    pub out: TensorIr,
}

/// 2D convolution gradient w.r.t. the input `x`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv2dXBackwardOpIr {
    pub x: TensorIr,
    pub weight: TensorIr,
    pub output_grad: TensorIr,
    pub options: Conv2dOptionsIr,
    pub out: TensorIr,
}

/// 2D convolution gradient w.r.t. the weight.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv2dWeightBackwardOpIr {
    pub x: TensorIr,
    pub weight: TensorIr,
    pub output_grad: TensorIr,
    pub options: Conv2dOptionsIr,
    pub out: TensorIr,
}

/// 2D convolution gradient w.r.t. the bias.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv2dBiasBackwardOpIr {
    pub x: TensorIr,
    pub bias: TensorIr,
    pub output_grad: TensorIr,
    pub out: TensorIr,
}
1157
/// Deformable 2D convolution forward pass, with learned offsets
/// and an optional modulation mask.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct DeformConv2dOpIr {
    pub x: TensorIr,
    pub offset: TensorIr,
    pub weight: TensorIr,
    pub mask: Option<TensorIr>,
    pub bias: Option<TensorIr>,
    pub options: DeformableConv2dOptionsIr,
    pub out: TensorIr,
}

/// Deformable 2D convolution backward pass; produces gradients for every
/// forward input (mask/bias gradients only when those inputs were present).
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct DeformConv2dBackwardOpIr {
    pub x: TensorIr,
    pub offset: TensorIr,
    pub weight: TensorIr,
    pub mask: Option<TensorIr>,
    pub bias: Option<TensorIr>,
    pub out_grad: TensorIr,
    pub options: DeformableConv2dOptionsIr,
    pub input_grad: TensorIr,
    pub offset_grad: TensorIr,
    pub weight_grad: TensorIr,
    pub mask_grad: Option<TensorIr>,
    pub bias_grad: Option<TensorIr>,
}
1186
/// 3D convolution forward pass.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv3dOpIr {
    pub x: TensorIr,
    pub weight: TensorIr,
    pub bias: Option<TensorIr>,
    pub options: Conv3dOptionsIr,
    pub out: TensorIr,
}

/// 3D convolution gradient w.r.t. the input `x`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv3dXBackwardOpIr {
    pub x: TensorIr,
    pub weight: TensorIr,
    pub output_grad: TensorIr,
    pub options: Conv3dOptionsIr,
    pub out: TensorIr,
}

/// 3D convolution gradient w.r.t. the weight.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv3dWeightBackwardOpIr {
    pub x: TensorIr,
    pub weight: TensorIr,
    pub output_grad: TensorIr,
    pub options: Conv3dOptionsIr,
    pub out: TensorIr,
}

/// 3D convolution gradient w.r.t. the bias.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv3dBiasBackwardOpIr {
    pub x: TensorIr,
    pub bias: TensorIr,
    pub output_grad: TensorIr,
    pub out: TensorIr,
}

/// Transposed 1D convolution forward pass.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose1dOpIr {
    pub x: TensorIr,
    pub weight: TensorIr,
    pub bias: Option<TensorIr>,
    pub options: ConvTranspose1dOptionsIr,
    pub out: TensorIr,
}

/// Transposed 2D convolution forward pass.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose2dOpIr {
    pub x: TensorIr,
    pub weight: TensorIr,
    pub bias: Option<TensorIr>,
    pub options: ConvTranspose2dOptionsIr,
    pub out: TensorIr,
}

/// Transposed 3D convolution forward pass.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose3dOpIr {
    pub x: TensorIr,
    pub weight: TensorIr,
    pub bias: Option<TensorIr>,
    pub options: ConvTranspose3dOptionsIr,
    pub out: TensorIr,
}
1255
/// Serializable mirror of `ConvOptions<1>` (see the `From` impls below).
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv1dOptionsIr {
    pub stride: [usize; 1],
    pub padding: [usize; 1],
    pub dilation: [usize; 1],
    pub groups: usize,
}

/// Serializable mirror of `ConvOptions<2>`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv2dOptionsIr {
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub dilation: [usize; 2],
    pub groups: usize,
}

/// Serializable mirror of `DeformConvOptions<2>`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct DeformableConv2dOptionsIr {
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub dilation: [usize; 2],
    pub weight_groups: usize,
    pub offset_groups: usize,
}

/// Serializable mirror of `ConvOptions<3>`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv3dOptionsIr {
    pub stride: [usize; 3],
    pub padding: [usize; 3],
    pub dilation: [usize; 3],
    pub groups: usize,
}

/// Serializable mirror of `ConvTransposeOptions<1>`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose1dOptionsIr {
    pub stride: [usize; 1],
    pub padding: [usize; 1],
    pub padding_out: [usize; 1],
    pub dilation: [usize; 1],
    pub groups: usize,
}

/// Serializable mirror of `ConvTransposeOptions<2>`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose2dOptionsIr {
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub padding_out: [usize; 2],
    pub dilation: [usize; 2],
    pub groups: usize,
}

/// Serializable mirror of `ConvTransposeOptions<3>`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose3dOptionsIr {
    pub stride: [usize; 3],
    pub padding: [usize; 3],
    pub padding_out: [usize; 3],
    pub dilation: [usize; 3],
    pub groups: usize,
}
1322
/// Quantization parameters carried as IR (per-block/per-tensor scales).
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct QuantizationParametersIr {
    /// Scales tensor used to (de)quantize values.
    pub scales: TensorIr,
}

/// Quantizes a float tensor according to `scheme` and `qparams`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct QuantizeOpIr {
    pub tensor: TensorIr,
    pub qparams: QuantizationParametersIr,
    pub scheme: QuantScheme,
    pub out: TensorIr,
}

/// Dequantizes a quantized tensor back to floats.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct DequantizeOpIr {
    pub input: TensorIr,
    pub out: TensorIr,
}
1345
1346impl From<ConvOptions<1>> for Conv1dOptionsIr {
1347 fn from(value: ConvOptions<1>) -> Self {
1348 Self {
1349 stride: value.stride,
1350 padding: value.padding,
1351 dilation: value.dilation,
1352 groups: value.groups,
1353 }
1354 }
1355}
1356
1357impl From<ConvOptions<2>> for Conv2dOptionsIr {
1358 fn from(value: ConvOptions<2>) -> Self {
1359 Self {
1360 stride: value.stride,
1361 padding: value.padding,
1362 dilation: value.dilation,
1363 groups: value.groups,
1364 }
1365 }
1366}
1367
1368impl From<ConvOptions<3>> for Conv3dOptionsIr {
1369 fn from(value: ConvOptions<3>) -> Self {
1370 Self {
1371 stride: value.stride,
1372 padding: value.padding,
1373 dilation: value.dilation,
1374 groups: value.groups,
1375 }
1376 }
1377}
1378
1379impl From<DeformConvOptions<2>> for DeformableConv2dOptionsIr {
1380 fn from(value: DeformConvOptions<2>) -> Self {
1381 Self {
1382 stride: value.stride,
1383 padding: value.padding,
1384 dilation: value.dilation,
1385 weight_groups: value.weight_groups,
1386 offset_groups: value.offset_groups,
1387 }
1388 }
1389}
1390
1391impl From<ConvTransposeOptions<1>> for ConvTranspose1dOptionsIr {
1392 fn from(value: ConvTransposeOptions<1>) -> Self {
1393 Self {
1394 stride: value.stride,
1395 padding: value.padding,
1396 padding_out: value.padding_out,
1397 dilation: value.dilation,
1398 groups: value.groups,
1399 }
1400 }
1401}
1402
1403impl From<ConvTransposeOptions<2>> for ConvTranspose2dOptionsIr {
1404 fn from(value: ConvTransposeOptions<2>) -> Self {
1405 Self {
1406 stride: value.stride,
1407 padding: value.padding,
1408 padding_out: value.padding_out,
1409 dilation: value.dilation,
1410 groups: value.groups,
1411 }
1412 }
1413}
1414
1415impl From<ConvTransposeOptions<3>> for ConvTranspose3dOptionsIr {
1416 fn from(value: ConvTransposeOptions<3>) -> Self {
1417 Self {
1418 stride: value.stride,
1419 padding: value.padding,
1420 padding_out: value.padding_out,
1421 dilation: value.dilation,
1422 groups: value.groups,
1423 }
1424 }
1425}
1426
1427impl From<Conv1dOptionsIr> for ConvOptions<1> {
1428 fn from(val: Conv1dOptionsIr) -> Self {
1429 ConvOptions {
1430 stride: val.stride,
1431 padding: val.padding,
1432 dilation: val.dilation,
1433 groups: val.groups,
1434 }
1435 }
1436}
1437
1438impl From<Conv2dOptionsIr> for ConvOptions<2> {
1439 fn from(val: Conv2dOptionsIr) -> Self {
1440 ConvOptions {
1441 stride: val.stride,
1442 padding: val.padding,
1443 dilation: val.dilation,
1444 groups: val.groups,
1445 }
1446 }
1447}
1448
1449impl From<Conv3dOptionsIr> for ConvOptions<3> {
1450 fn from(val: Conv3dOptionsIr) -> Self {
1451 ConvOptions {
1452 stride: val.stride,
1453 padding: val.padding,
1454 dilation: val.dilation,
1455 groups: val.groups,
1456 }
1457 }
1458}
1459
1460impl From<DeformableConv2dOptionsIr> for DeformConvOptions<2> {
1461 fn from(value: DeformableConv2dOptionsIr) -> Self {
1462 DeformConvOptions {
1463 stride: value.stride,
1464 padding: value.padding,
1465 dilation: value.dilation,
1466 weight_groups: value.weight_groups,
1467 offset_groups: value.offset_groups,
1468 }
1469 }
1470}
1471
1472impl From<ConvTranspose1dOptionsIr> for ConvTransposeOptions<1> {
1473 fn from(val: ConvTranspose1dOptionsIr) -> Self {
1474 ConvTransposeOptions {
1475 stride: val.stride,
1476 padding: val.padding,
1477 padding_out: val.padding_out,
1478 dilation: val.dilation,
1479 groups: val.groups,
1480 }
1481 }
1482}
1483
1484impl From<ConvTranspose2dOptionsIr> for ConvTransposeOptions<2> {
1485 fn from(val: ConvTranspose2dOptionsIr) -> Self {
1486 ConvTransposeOptions {
1487 stride: val.stride,
1488 padding: val.padding,
1489 padding_out: val.padding_out,
1490 dilation: val.dilation,
1491 groups: val.groups,
1492 }
1493 }
1494}
1495
1496impl From<ConvTranspose3dOptionsIr> for ConvTransposeOptions<3> {
1497 fn from(val: ConvTranspose3dOptionsIr) -> Self {
1498 ConvTransposeOptions {
1499 stride: val.stride,
1500 padding: val.padding,
1501 padding_out: val.padding_out,
1502 dilation: val.dilation,
1503 groups: val.groups,
1504 }
1505 }
1506}
1507
/// 1D average pooling forward pass.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AvgPool1dOpIr {
    pub x: TensorIr,
    pub kernel_size: usize,
    pub stride: usize,
    pub padding: usize,
    pub count_include_pad: bool,
    pub ceil_mode: bool,
    pub out: TensorIr,
}

/// 2D average pooling forward pass.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AvgPool2dOpIr {
    pub x: TensorIr,
    pub kernel_size: [usize; 2],
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub count_include_pad: bool,
    pub ceil_mode: bool,
    pub out: TensorIr,
}

/// 1D average pooling backward pass.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AvgPool1dBackwardOpIr {
    pub x: TensorIr,
    pub grad: TensorIr,
    pub kernel_size: usize,
    pub stride: usize,
    pub padding: usize,
    pub count_include_pad: bool,
    pub ceil_mode: bool,
    pub out: TensorIr,
}

/// 2D average pooling backward pass.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AvgPool2dBackwardOpIr {
    pub x: TensorIr,
    pub grad: TensorIr,
    pub kernel_size: [usize; 2],
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub count_include_pad: bool,
    pub ceil_mode: bool,
    pub out: TensorIr,
}

/// Adaptive 1D average pooling to a target output size.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool1dOpIr {
    pub x: TensorIr,
    pub output_size: usize,
    pub out: TensorIr,
}

/// Adaptive 2D average pooling to a target output size.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool2dOpIr {
    pub x: TensorIr,
    pub output_size: [usize; 2],
    pub out: TensorIr,
}

/// Adaptive 1D average pooling backward pass.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool1dBackwardOpIr {
    pub x: TensorIr,
    pub grad: TensorIr,
    pub out: TensorIr,
}

/// Adaptive 2D average pooling backward pass.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool2dBackwardOpIr {
    pub x: TensorIr,
    pub grad: TensorIr,
    pub out: TensorIr,
}
1589
/// 1D max pooling forward pass.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool1dOpIr {
    pub x: TensorIr,
    pub kernel_size: usize,
    pub stride: usize,
    pub padding: usize,
    pub dilation: usize,
    pub ceil_mode: bool,
    pub out: TensorIr,
}

/// 1D max pooling that also returns the argmax indices.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool1dWithIndicesOpIr {
    pub x: TensorIr,
    pub kernel_size: usize,
    pub stride: usize,
    pub padding: usize,
    pub dilation: usize,
    pub ceil_mode: bool,
    pub out: TensorIr,
    pub out_indices: TensorIr,
}

/// Backward pass of 1D max pooling, routed through the saved indices.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool1dWithIndicesBackwardOpIr {
    pub x: TensorIr,
    pub grad: TensorIr,
    pub indices: TensorIr,
    pub kernel_size: usize,
    pub stride: usize,
    pub padding: usize,
    pub dilation: usize,
    pub ceil_mode: bool,
    pub out: TensorIr,
}

/// 2D max pooling forward pass.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool2dOpIr {
    pub x: TensorIr,
    pub kernel_size: [usize; 2],
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub dilation: [usize; 2],
    pub ceil_mode: bool,
    pub out: TensorIr,
}

/// 2D max pooling that also returns the argmax indices.
#[allow(missing_docs)]
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct MaxPool2dWithIndicesOpIr {
    pub x: TensorIr,
    pub kernel_size: [usize; 2],
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub dilation: [usize; 2],
    pub ceil_mode: bool,
    pub out: TensorIr,
    pub out_indices: TensorIr,
}

/// Backward pass of 2D max pooling, routed through the saved indices.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool2dWithIndicesBackwardOpIr {
    pub x: TensorIr,
    pub grad: TensorIr,
    pub indices: TensorIr,
    pub kernel_size: [usize; 2],
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub dilation: [usize; 2],
    pub ceil_mode: bool,
    pub out: TensorIr,
}
1667
1668#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
1669#[allow(missing_docs)]
1670pub enum InterpolateModeIr {
1671 Nearest,
1672 Bilinear,
1673 Bicubic,
1674 Lanczos3,
1675}
1676
1677#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
1678#[allow(missing_docs)]
1679pub struct InterpolateOptionsIr {
1680 pub mode: InterpolateModeIr,
1681 pub align_corners: bool,
1682}
1683
1684#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
1685#[allow(missing_docs)]
1686pub struct InterpolateOpIr {
1687 pub x: TensorIr,
1688 pub output_size: [usize; 2],
1689 pub options: InterpolateOptionsIr,
1690 pub out: TensorIr,
1691}
1692
1693#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
1694#[allow(missing_docs)]
1695pub struct RfftOpIr {
1696 pub signal: TensorIr,
1697 pub dim: usize,
1698 pub n: Option<usize>,
1699 pub out_re: TensorIr,
1700 pub out_im: TensorIr,
1701}
1702
1703#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
1704#[allow(missing_docs)]
1705pub struct IRfftOpIr {
1706 pub input_re: TensorIr,
1707 pub input_im: TensorIr,
1708 pub dim: usize,
1709 pub n: Option<usize>,
1710 pub out_signal: TensorIr,
1711}
1712
1713#[allow(missing_docs)]
1714impl RfftOpIr {
1715 pub fn create<F>(signal: TensorIr, dim: usize, n: Option<usize>, mut new_id: F) -> Self
1716 where
1717 F: FnMut() -> crate::TensorId,
1718 {
1719 let mut shape = signal.shape.clone();
1722 let fft_len = n.unwrap_or(shape[dim]);
1723 shape[dim] = fft_len / 2 + 1;
1724 let dtype = signal.dtype;
1725
1726 Self {
1727 signal,
1728 dim,
1729 n,
1730 out_re: TensorIr::uninit(new_id(), shape.clone(), dtype),
1731 out_im: TensorIr::uninit(new_id(), shape, dtype),
1732 }
1733 }
1734}
1735
#[allow(missing_docs)]
impl IRfftOpIr {
    /// Creates the IR for an inverse real-valued FFT along `dim`, allocating
    /// the real output tensor via `new_id`.
    ///
    /// The output length along `dim` is `n` when given, otherwise
    /// `(spectrum_len - 1) * 2` — the inverse of rfft's `n / 2 + 1` length.
    pub fn create<F>(
        input_re: TensorIr,
        input_im: TensorIr,
        dim: usize,
        n: Option<usize>,
        mut new_id: F,
    ) -> Self
    where
        F: FnMut() -> crate::TensorId,
    {
        debug_assert!(
            input_re.shape[dim] >= 1,
            "IRfftOpIr: input spectrum dimension must be >= 1"
        );
        debug_assert!(
            !matches!(n, Some(0)),
            "IRfftOpIr: n must be >= 1 when specified"
        );
        let mut shape = input_re.shape.clone();
        // `(len - 1) * 2` underflows for len == 0; the debug_assert above
        // catches that case in debug builds.
        shape[dim] = n.unwrap_or((shape[dim] - 1) * 2);
        let dtype = input_re.dtype;

        Self {
            input_re,
            input_im,
            dim,
            n,
            out_signal: TensorIr::uninit(new_id(), shape, dtype),
        }
    }
}
1769
/// Serializable mirror of [`AttentionModuleOptions`].
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AttentionOptionsIr {
    /// Optional softmax scaling factor.
    pub scale: Option<ScalarIr>,
    /// Optional logit soft-capping value.
    pub softcap: Option<ScalarIr>,
    /// Whether a causal (lower-triangular) mask is applied.
    pub is_causal: bool,
}
1777
1778impl From<AttentionOptionsIr> for AttentionModuleOptions {
1779 fn from(ir: AttentionOptionsIr) -> Self {
1780 AttentionModuleOptions {
1781 scale: ir.scale.map(|s| s.elem()),
1782 softcap: ir.softcap.map(|s| s.elem()),
1783 is_causal: ir.is_causal,
1784 }
1785 }
1786}
1787
1788impl From<AttentionModuleOptions> for AttentionOptionsIr {
1789 fn from(ir: AttentionModuleOptions) -> Self {
1790 AttentionOptionsIr {
1791 scale: ir.scale.map(ScalarIr::Float),
1792 softcap: ir.softcap.map(ScalarIr::Float),
1793 is_causal: ir.is_causal,
1794 }
1795 }
1796}
1797
/// IR of a scaled dot-product attention operation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AttentionOpIr {
    pub query: TensorIr,
    pub key: TensorIr,
    pub value: TensorIr,
    /// Optional attention mask.
    pub mask: Option<TensorIr>,
    /// Optional additive attention bias.
    pub attn_bias: Option<TensorIr>,
    pub options: AttentionOptionsIr,
    pub out: TensorIr,
}

/// IR of a CTC (connectionist temporal classification) loss operation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct CtcLossOpIr {
    pub log_probs: TensorIr,
    pub targets: TensorIr,
    pub input_lengths: TensorIr,
    pub target_lengths: TensorIr,
    /// Index of the blank label.
    pub blank: usize,
    pub out: TensorIr,
}

/// IR of the backward pass of the CTC loss.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct CtcLossBackwardOpIr {
    pub log_probs: TensorIr,
    pub targets: TensorIr,
    pub input_lengths: TensorIr,
    pub target_lengths: TensorIr,
    /// Gradient flowing back from the loss value.
    pub grad_loss: TensorIr,
    /// Index of the blank label.
    pub blank: usize,
    pub out: TensorIr,
}
1832
1833impl From<InterpolateModeIr> for InterpolateMode {
1834 fn from(val: InterpolateModeIr) -> Self {
1835 match val {
1836 InterpolateModeIr::Nearest => Self::Nearest,
1837 InterpolateModeIr::Bilinear => Self::Bilinear,
1838 InterpolateModeIr::Bicubic => Self::Bicubic,
1839 InterpolateModeIr::Lanczos3 => Self::Lanczos3,
1840 }
1841 }
1842}
1843
/// Converts the serializable interpolation options into the backend options
/// (via the backend's builder API).
impl From<InterpolateOptionsIr> for InterpolateOptions {
    fn from(val: InterpolateOptionsIr) -> Self {
        Self::new(val.mode.into()).with_align_corners(val.align_corners)
    }
}

/// Converts the backend interpolation mode into its serializable IR form.
impl From<InterpolateMode> for InterpolateModeIr {
    fn from(val: InterpolateMode) -> Self {
        match val {
            InterpolateMode::Nearest => Self::Nearest,
            InterpolateMode::Bilinear => Self::Bilinear,
            InterpolateMode::Bicubic => Self::Bicubic,
            InterpolateMode::Lanczos3 => Self::Lanczos3,
        }
    }
}

/// Converts the backend interpolation options into their serializable IR form.
impl From<InterpolateOptions> for InterpolateOptionsIr {
    fn from(val: InterpolateOptions) -> Self {
        Self {
            mode: val.mode.into(),
            align_corners: val.align_corners,
        }
    }
}
1869
/// IR of the backward pass of a 2D interpolation (resize) operation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct InterpolateBackwardOpIr {
    /// Input tensor of the forward pass.
    pub x: TensorIr,
    /// Gradient flowing back from the forward output.
    pub grad: TensorIr,
    /// Target spatial size as `[height, width]`.
    pub output_size: [usize; 2],
    pub options: InterpolateOptionsIr,
    /// Output tensor (gradient w.r.t. `x`).
    pub out: TensorIr,
}

/// Serializable mirror of [`GridSamplePaddingMode`].
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub enum GridSamplePaddingModeIr {
    Zeros,
    Border,
    Reflection,
}

/// Serializable mirror of [`GridSampleOptions`].
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct GridSampleOptionsIr {
    pub mode: InterpolateModeIr,
    pub padding_mode: GridSamplePaddingModeIr,
    pub align_corners: bool,
}

/// IR of a 2D grid sampling operation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct GridSample2dOpIr {
    /// Tensor to sample from.
    pub tensor: TensorIr,
    /// Sampling grid of coordinates.
    pub grid: TensorIr,
    pub options: GridSampleOptionsIr,
    pub out: TensorIr,
}
1904
1905impl From<GridSamplePaddingModeIr> for GridSamplePaddingMode {
1906 fn from(val: GridSamplePaddingModeIr) -> Self {
1907 match val {
1908 GridSamplePaddingModeIr::Zeros => Self::Zeros,
1909 GridSamplePaddingModeIr::Border => Self::Border,
1910 GridSamplePaddingModeIr::Reflection => Self::Reflection,
1911 }
1912 }
1913}
1914
1915impl From<GridSamplePaddingMode> for GridSamplePaddingModeIr {
1916 fn from(val: GridSamplePaddingMode) -> Self {
1917 match val {
1918 GridSamplePaddingMode::Zeros => Self::Zeros,
1919 GridSamplePaddingMode::Border => Self::Border,
1920 GridSamplePaddingMode::Reflection => Self::Reflection,
1921 }
1922 }
1923}
1924
/// Converts the serializable grid sampling options into the backend options.
impl From<GridSampleOptionsIr> for GridSampleOptions {
    fn from(val: GridSampleOptionsIr) -> Self {
        Self {
            mode: val.mode.into(),
            padding_mode: val.padding_mode.into(),
            align_corners: val.align_corners,
        }
    }
}

/// Converts the backend grid sampling options into their serializable IR form.
impl From<GridSampleOptions> for GridSampleOptionsIr {
    fn from(val: GridSampleOptions) -> Self {
        Self {
            mode: val.mode.into(),
            padding_mode: val.padding_mode.into(),
            align_corners: val.align_corners,
        }
    }
}
1944
impl OperationIr {
    /// Returns an iterator over the tensors this operation reads.
    pub fn inputs(&self) -> impl Iterator<Item = &TensorIr> {
        match self {
            OperationIr::BaseFloat(repr) => repr.inputs(),
            OperationIr::BaseInt(repr) => repr.inputs(),
            OperationIr::BaseBool(repr) => repr.inputs(),
            OperationIr::NumericFloat(_dtype, repr) => repr.inputs(),
            OperationIr::NumericInt(_dtype, repr) => repr.inputs(),
            OperationIr::Bool(repr) => repr.inputs(),
            OperationIr::Int(repr) => repr.inputs(),
            OperationIr::Float(_dtype, repr) => repr.inputs(),
            OperationIr::Module(repr) => repr.inputs(),
            OperationIr::Init(repr) => repr.inputs(),
            OperationIr::Custom(repr) => repr.inputs(),
            // Drop reads the tensor it releases, and produces nothing.
            OperationIr::Drop(repr) => Box::new([repr].into_iter()),
            #[cfg(feature = "distributed")]
            OperationIr::Distributed(repr) => repr.inputs(),
        }
    }

    /// Returns an iterator over the tensors this operation produces.
    pub fn outputs(&self) -> impl Iterator<Item = &TensorIr> {
        match self {
            OperationIr::BaseFloat(repr) => repr.outputs(),
            OperationIr::BaseInt(repr) => repr.outputs(),
            OperationIr::BaseBool(repr) => repr.outputs(),
            OperationIr::NumericFloat(_dtype, repr) => repr.outputs(),
            OperationIr::NumericInt(_dtype, repr) => repr.outputs(),
            OperationIr::Bool(repr) => repr.outputs(),
            OperationIr::Int(repr) => repr.outputs(),
            OperationIr::Float(_dtype, repr) => repr.outputs(),
            OperationIr::Module(repr) => repr.outputs(),
            OperationIr::Init(repr) => repr.outputs(),
            OperationIr::Custom(repr) => repr.outputs(),
            // Drop has no outputs.
            OperationIr::Drop(_repr) => Box::new([].into_iter()),
            #[cfg(feature = "distributed")]
            OperationIr::Distributed(repr) => repr.outputs(),
        }
    }

    /// Collects all tensors involved in this operation (inputs followed by outputs).
    pub fn nodes(&self) -> Vec<&TensorIr> {
        self.inputs().chain(self.outputs()).collect()
    }

    /// Marks the input tensors whose ids appear in `nodes` as read-only,
    /// delegating to each variant's own `mark_read_only`, and returns the
    /// tensor IRs that were affected.
    pub fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
        match self {
            OperationIr::BaseFloat(repr) => repr.mark_read_only(nodes),
            OperationIr::BaseInt(repr) => repr.mark_read_only(nodes),
            OperationIr::BaseBool(repr) => repr.mark_read_only(nodes),
            OperationIr::NumericFloat(_dtype, repr) => repr.mark_read_only(nodes),
            OperationIr::NumericInt(_dtype, repr) => repr.mark_read_only(nodes),
            OperationIr::Bool(repr) => repr.mark_read_only(nodes),
            OperationIr::Int(repr) => repr.mark_read_only(nodes),
            OperationIr::Float(_dtype, repr) => repr.mark_read_only(nodes),
            OperationIr::Module(repr) => repr.mark_read_only(nodes),
            // Init creates tensors; it has no inputs to mark.
            OperationIr::Init(_) => Vec::new(),
            OperationIr::Drop(repr) => {
                let mut output = Vec::new();
                repr.mark_read_only(nodes, &mut output);
                output
            }
            OperationIr::Custom(repr) => {
                let mut output = Vec::new();

                for input in repr.inputs.iter_mut() {
                    input.mark_read_only(nodes, &mut output);
                }

                output
            }
            #[cfg(feature = "distributed")]
            OperationIr::Distributed(repr) => repr.mark_read_only(nodes),
        }
    }
}
2026
impl BaseOperationIr {
    /// Returns an iterator over the tensors read by this base operation.
    /// Creation ops (`Empty`, `Ones`, `Zeros`) have no inputs.
    fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
        match self {
            BaseOperationIr::Reshape(repr) => Box::new([&repr.input].into_iter()),
            BaseOperationIr::SwapDims(repr) => Box::new([&repr.input].into_iter()),
            BaseOperationIr::Permute(repr) => Box::new([&repr.input].into_iter()),
            BaseOperationIr::Expand(repr) => Box::new([&repr.input].into_iter()),
            BaseOperationIr::Flip(repr) => Box::new([&repr.input].into_iter()),
            BaseOperationIr::Slice(repr) => Box::new([&repr.tensor].into_iter()),
            BaseOperationIr::SliceAssign(repr) => Box::new([&repr.tensor, &repr.value].into_iter()),
            BaseOperationIr::Gather(repr) => Box::new([&repr.tensor, &repr.indices].into_iter()),
            BaseOperationIr::Scatter(repr) => {
                Box::new([&repr.tensor, &repr.indices, &repr.value].into_iter())
            }
            BaseOperationIr::ScatterNd(repr) => {
                Box::new([&repr.data, &repr.indices, &repr.values].into_iter())
            }
            BaseOperationIr::GatherNd(repr) => Box::new([&repr.data, &repr.indices].into_iter()),
            BaseOperationIr::Select(repr) => Box::new([&repr.tensor, &repr.indices].into_iter()),
            BaseOperationIr::SelectAssign(repr) => {
                Box::new([&repr.tensor, &repr.indices, &repr.value].into_iter())
            }
            BaseOperationIr::MaskWhere(repr) => {
                Box::new([&repr.tensor, &repr.mask, &repr.value].into_iter())
            }
            BaseOperationIr::MaskFill(repr) => Box::new([&repr.tensor, &repr.mask].into_iter()),
            BaseOperationIr::Equal(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
            BaseOperationIr::EqualElem(repr) => Box::new([&repr.lhs].into_iter()),
            BaseOperationIr::RepeatDim(repr) => Box::new([&repr.tensor].into_iter()),
            // Cat is the only variant with a variable number of inputs.
            BaseOperationIr::Cat(repr) => Box::new(repr.tensors.iter()),
            BaseOperationIr::Cast(repr) => Box::new([&repr.input].into_iter()),
            BaseOperationIr::Unfold(repr) => Box::new([&repr.input].into_iter()),
            BaseOperationIr::Empty(_repr) => Box::new([].into_iter()),
            BaseOperationIr::Ones(_repr) => Box::new([].into_iter()),
            BaseOperationIr::Zeros(_repr) => Box::new([].into_iter()),
        }
    }

    /// Returns an iterator over the tensors produced by this base operation.
    /// Every variant produces exactly one output tensor.
    fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
        match self {
            BaseOperationIr::Reshape(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::SwapDims(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::Permute(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::Expand(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::Flip(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::Slice(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::SliceAssign(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::Gather(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::Scatter(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::ScatterNd(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::GatherNd(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::Select(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::SelectAssign(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::MaskWhere(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::MaskFill(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::Equal(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::EqualElem(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::RepeatDim(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::Cat(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::Cast(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::Unfold(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::Empty(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::Ones(repr) => Box::new([&repr.out].into_iter()),
            BaseOperationIr::Zeros(repr) => Box::new([&repr.out].into_iter()),
        }
    }

    /// Marks the input tensors whose ids appear in `nodes` as read-only and
    /// returns the affected tensor IRs. Creation ops have no inputs to mark.
    fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
        let mut output = Vec::new();

        match self {
            BaseOperationIr::Reshape(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::SwapDims(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::Permute(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }

            BaseOperationIr::Expand(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }

            BaseOperationIr::Flip(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::Slice(repr) => {
                repr.tensor.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::SliceAssign(repr) => {
                repr.tensor.mark_read_only(nodes, &mut output);
                repr.value.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::Gather(repr) => {
                repr.tensor.mark_read_only(nodes, &mut output);
                repr.indices.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::Scatter(repr) => {
                repr.tensor.mark_read_only(nodes, &mut output);
                repr.indices.mark_read_only(nodes, &mut output);
                repr.value.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::ScatterNd(repr) => {
                repr.data.mark_read_only(nodes, &mut output);
                repr.indices.mark_read_only(nodes, &mut output);
                repr.values.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::GatherNd(repr) => {
                repr.data.mark_read_only(nodes, &mut output);
                repr.indices.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::Select(repr) => {
                repr.tensor.mark_read_only(nodes, &mut output);
                repr.indices.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::SelectAssign(repr) => {
                repr.tensor.mark_read_only(nodes, &mut output);
                repr.indices.mark_read_only(nodes, &mut output);
                repr.value.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::MaskWhere(repr) => {
                repr.tensor.mark_read_only(nodes, &mut output);
                repr.mask.mark_read_only(nodes, &mut output);
                repr.value.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::MaskFill(repr) => {
                repr.tensor.mark_read_only(nodes, &mut output);
                repr.mask.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::Equal(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::EqualElem(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::RepeatDim(repr) => {
                repr.tensor.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::Cat(repr) => {
                for t in repr.tensors.iter_mut() {
                    t.mark_read_only(nodes, &mut output);
                }
            }
            BaseOperationIr::Cast(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::Unfold(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::Empty(_) => {}
            BaseOperationIr::Zeros(_) => {}
            BaseOperationIr::Ones(_) => {}
        };

        output
    }
}
2187
impl NumericOperationIr {
    /// Returns an iterator over the tensors read by this numeric operation.
    /// Creation ops (`Full`, `IntRandom`) have no inputs; scalar-variant ops
    /// read only their `lhs` tensor.
    fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
        match self {
            NumericOperationIr::Add(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
            NumericOperationIr::AddScalar(repr) => Box::new([&repr.lhs].into_iter()),
            NumericOperationIr::Sub(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
            NumericOperationIr::SubScalar(repr) => Box::new([&repr.lhs].into_iter()),
            NumericOperationIr::Mul(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
            NumericOperationIr::MulScalar(repr) => Box::new([&repr.lhs].into_iter()),
            NumericOperationIr::Div(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
            NumericOperationIr::DivScalar(repr) => Box::new([&repr.lhs].into_iter()),
            NumericOperationIr::Rem(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
            NumericOperationIr::RemScalar(repr) => Box::new([&repr.lhs].into_iter()),
            NumericOperationIr::GreaterElem(repr) => Box::new([&repr.lhs].into_iter()),
            NumericOperationIr::GreaterEqualElem(repr) => Box::new([&repr.lhs].into_iter()),
            NumericOperationIr::LowerElem(repr) => Box::new([&repr.lhs].into_iter()),
            NumericOperationIr::LowerEqualElem(repr) => Box::new([&repr.lhs].into_iter()),
            NumericOperationIr::Greater(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
            NumericOperationIr::GreaterEqual(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
            NumericOperationIr::Lower(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
            NumericOperationIr::LowerEqual(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
            NumericOperationIr::ArgMax(repr) => Box::new([&repr.input].into_iter()),
            NumericOperationIr::ArgTopK(repr) => Box::new([&repr.input].into_iter()),
            NumericOperationIr::TopK(repr) => Box::new([&repr.input].into_iter()),
            NumericOperationIr::ArgMin(repr) => Box::new([&repr.input].into_iter()),
            NumericOperationIr::Clamp(repr) => Box::new([&repr.tensor].into_iter()),
            NumericOperationIr::Abs(repr) => Box::new([&repr.input].into_iter()),
            NumericOperationIr::Full(_repr) => Box::new([].into_iter()),
            NumericOperationIr::MeanDim(repr) => Box::new([&repr.input].into_iter()),
            NumericOperationIr::Mean(repr) => Box::new([&repr.input].into_iter()),
            NumericOperationIr::Sum(repr) => Box::new([&repr.input].into_iter()),
            NumericOperationIr::SumDim(repr) => Box::new([&repr.input].into_iter()),
            NumericOperationIr::Prod(repr) => Box::new([&repr.input].into_iter()),
            NumericOperationIr::ProdDim(repr) => Box::new([&repr.input].into_iter()),
            NumericOperationIr::Max(repr) => Box::new([&repr.input].into_iter()),
            NumericOperationIr::MaxDimWithIndices(repr) => Box::new([&repr.tensor].into_iter()),
            NumericOperationIr::MinDimWithIndices(repr) => Box::new([&repr.tensor].into_iter()),
            NumericOperationIr::Min(repr) => Box::new([&repr.input].into_iter()),
            NumericOperationIr::MaxDim(repr) => Box::new([&repr.input].into_iter()),
            NumericOperationIr::MinDim(repr) => Box::new([&repr.input].into_iter()),
            NumericOperationIr::MaxAbs(repr) => Box::new([&repr.input].into_iter()),
            NumericOperationIr::MaxAbsDim(repr) => Box::new([&repr.input].into_iter()),
            NumericOperationIr::IntRandom(_repr) => Box::new([].into_iter()),
            NumericOperationIr::Powi(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
            NumericOperationIr::PowiScalar(repr) => Box::new([&repr.lhs].into_iter()),
            NumericOperationIr::CumMin(repr) => Box::new([&repr.input].into_iter()),
            NumericOperationIr::CumMax(repr) => Box::new([&repr.input].into_iter()),
            NumericOperationIr::CumProd(repr) => Box::new([&repr.input].into_iter()),
            NumericOperationIr::CumSum(repr) => Box::new([&repr.input].into_iter()),
        }
    }

    /// Returns an iterator over the tensors produced by this numeric operation.
    /// Only the `*DimWithIndices` variants produce two outputs (values + indices).
    fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
        match self {
            NumericOperationIr::Add(repr) => Box::new([&repr.out].into_iter()),
            NumericOperationIr::AddScalar(repr) => Box::new([&repr.out].into_iter()),
            NumericOperationIr::Sub(repr) => Box::new([&repr.out].into_iter()),
            NumericOperationIr::SubScalar(repr) => Box::new([&repr.out].into_iter()),
            NumericOperationIr::Mul(repr) => Box::new([&repr.out].into_iter()),
            NumericOperationIr::MulScalar(repr) => Box::new([&repr.out].into_iter()),
            NumericOperationIr::Div(repr) => Box::new([&repr.out].into_iter()),
            NumericOperationIr::DivScalar(repr) => Box::new([&repr.out].into_iter()),
            NumericOperationIr::Rem(repr) => Box::new([&repr.out].into_iter()),
            NumericOperationIr::RemScalar(repr) => Box::new([&repr.out].into_iter()),
            NumericOperationIr::GreaterElem(repr) => Box::new([&repr.out].into_iter()),
            NumericOperationIr::GreaterEqualElem(repr) => Box::new([&repr.out].into_iter()),
            NumericOperationIr::LowerElem(repr) => Box::new([&repr.out].into_iter()),
            NumericOperationIr::LowerEqualElem(repr) => Box::new([&repr.out].into_iter()),
            NumericOperationIr::Greater(repr) => Box::new([&repr.out].into_iter()),
            NumericOperationIr::GreaterEqual(repr) => Box::new([&repr.out].into_iter()),
            NumericOperationIr::Lower(repr) => Box::new([&repr.out].into_iter()),
            NumericOperationIr::LowerEqual(repr) => Box::new([&repr.out].into_iter()),
            NumericOperationIr::ArgMax(repr) => Box::new([&repr.out].into_iter()),
            NumericOperationIr::ArgTopK(repr) => Box::new([&repr.out].into_iter()),
            NumericOperationIr::TopK(repr) => Box::new([&repr.out].into_iter()),
            NumericOperationIr::ArgMin(repr) => Box::new([&repr.out].into_iter()),
            NumericOperationIr::Clamp(repr) => Box::new([&repr.out].into_iter()),
            NumericOperationIr::Abs(repr) => Box::new([&repr.out].into_iter()),
            NumericOperationIr::Full(repr) => Box::new([&repr.out].into_iter()),
            NumericOperationIr::MeanDim(repr) => Box::new([&repr.out].into_iter()),
            NumericOperationIr::Mean(repr) => Box::new([&repr.out].into_iter()),
            NumericOperationIr::Sum(repr) => Box::new([&repr.out].into_iter()),
            NumericOperationIr::SumDim(repr) => Box::new([&repr.out].into_iter()),
            NumericOperationIr::Prod(repr) => Box::new([&repr.out].into_iter()),
            NumericOperationIr::ProdDim(repr) => Box::new([&repr.out].into_iter()),
            NumericOperationIr::Max(repr) => Box::new([&repr.out].into_iter()),
            NumericOperationIr::MaxDimWithIndices(repr) => {
                Box::new([&repr.out, &repr.out_indices].into_iter())
            }
            NumericOperationIr::MinDimWithIndices(repr) => {
                Box::new([&repr.out, &repr.out_indices].into_iter())
            }
            NumericOperationIr::Min(repr) => Box::new([&repr.out].into_iter()),
            NumericOperationIr::MaxDim(repr) => Box::new([&repr.out].into_iter()),
            NumericOperationIr::MinDim(repr) => Box::new([&repr.out].into_iter()),
            NumericOperationIr::MaxAbs(repr) => Box::new([&repr.out].into_iter()),
            NumericOperationIr::MaxAbsDim(repr) => Box::new([&repr.out].into_iter()),
            NumericOperationIr::IntRandom(repr) => Box::new([&repr.out].into_iter()),
            NumericOperationIr::Powi(repr) => Box::new([&repr.out].into_iter()),
            NumericOperationIr::PowiScalar(repr) => Box::new([&repr.out].into_iter()),
            NumericOperationIr::CumMin(repr) => Box::new([&repr.out].into_iter()),
            NumericOperationIr::CumMax(repr) => Box::new([&repr.out].into_iter()),
            NumericOperationIr::CumProd(repr) => Box::new([&repr.out].into_iter()),
            NumericOperationIr::CumSum(repr) => Box::new([&repr.out].into_iter()),
        }
    }
    /// Marks the input tensors whose ids appear in `nodes` as read-only and
    /// returns the affected tensor IRs. Creation ops (`Full`, `IntRandom`)
    /// have no inputs to mark.
    fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
        let mut output = Vec::new();

        match self {
            NumericOperationIr::Add(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::AddScalar(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::Sub(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::SubScalar(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::Mul(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::MulScalar(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::Div(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::DivScalar(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::Rem(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::RemScalar(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::GreaterElem(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::GreaterEqualElem(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::LowerElem(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::LowerEqualElem(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::Greater(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::GreaterEqual(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::Lower(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::LowerEqual(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::ArgMax(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::ArgTopK(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::TopK(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::ArgMin(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::Clamp(repr) => {
                repr.tensor.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::Abs(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::Full(_) => {}
            NumericOperationIr::MeanDim(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::Mean(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::Sum(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::SumDim(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::Prod(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::ProdDim(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::Max(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::MaxDimWithIndices(repr) => {
                repr.tensor.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::MinDimWithIndices(repr) => {
                repr.tensor.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::Min(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::MaxDim(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::MinDim(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::MaxAbs(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::MaxAbsDim(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::IntRandom(_) => {}
            NumericOperationIr::Powi(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::PowiScalar(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::CumSum(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::CumProd(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::CumMin(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::CumMax(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
        };

        output
    }
}
2447
2448impl FloatOperationIr {
2449 fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
2450 match self {
2451 FloatOperationIr::Matmul(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
2452 FloatOperationIr::Cross(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
2453 FloatOperationIr::Random(_repr) => Box::new([].into_iter()),
2454 FloatOperationIr::Exp(repr) => Box::new([&repr.input].into_iter()),
2455 FloatOperationIr::Log(repr) => Box::new([&repr.input].into_iter()),
2456 FloatOperationIr::Log1p(repr) => Box::new([&repr.input].into_iter()),
2457 FloatOperationIr::Erf(repr) => Box::new([&repr.input].into_iter()),
2458 FloatOperationIr::Recip(repr) => Box::new([&repr.input].into_iter()),
2459 FloatOperationIr::PowfScalar(repr) => Box::new([&repr.lhs].into_iter()),
2460 FloatOperationIr::Sqrt(repr) => Box::new([&repr.input].into_iter()),
2461 FloatOperationIr::Cos(repr) => Box::new([&repr.input].into_iter()),
2462 FloatOperationIr::Sin(repr) => Box::new([&repr.input].into_iter()),
2463 FloatOperationIr::Tanh(repr) => Box::new([&repr.input].into_iter()),
2464 FloatOperationIr::Round(repr) => Box::new([&repr.input].into_iter()),
2465 FloatOperationIr::Floor(repr) => Box::new([&repr.input].into_iter()),
2466 FloatOperationIr::Ceil(repr) => Box::new([&repr.input].into_iter()),
2467 FloatOperationIr::Trunc(repr) => Box::new([&repr.input].into_iter()),
2468 FloatOperationIr::IntoInt(repr) => Box::new([&repr.input].into_iter()),
2469 FloatOperationIr::Quantize(repr) => {
2470 Box::new([&repr.tensor, &repr.qparams.scales].into_iter())
2471 }
2472 FloatOperationIr::Dequantize(repr) => Box::new([&repr.input].into_iter()),
2473 FloatOperationIr::IsNan(repr) => Box::new([&repr.input].into_iter()),
2474 FloatOperationIr::IsInf(repr) => Box::new([&repr.input].into_iter()),
2475 FloatOperationIr::GridSample2d(repr) => {
2476 Box::new([&repr.tensor, &repr.grid].into_iter())
2477 }
2478 FloatOperationIr::Tan(repr) => Box::new([&repr.input].into_iter()),
2479 FloatOperationIr::Cosh(repr) => Box::new([&repr.input].into_iter()),
2480 FloatOperationIr::Sinh(repr) => Box::new([&repr.input].into_iter()),
2481 FloatOperationIr::ArcCos(repr) => Box::new([&repr.input].into_iter()),
2482 FloatOperationIr::ArcCosh(repr) => Box::new([&repr.input].into_iter()),
2483 FloatOperationIr::ArcSin(repr) => Box::new([&repr.input].into_iter()),
2484 FloatOperationIr::ArcSinh(repr) => Box::new([&repr.input].into_iter()),
2485 FloatOperationIr::ArcTan(repr) => Box::new([&repr.input].into_iter()),
2486 FloatOperationIr::ArcTanh(repr) => Box::new([&repr.input].into_iter()),
2487 FloatOperationIr::ArcTan2(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
2488 FloatOperationIr::Powf(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
2489 }
2490 }
2491 fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
2492 match self {
2493 FloatOperationIr::Matmul(repr) => Box::new([&repr.out].into_iter()),
2494 FloatOperationIr::Cross(repr) => Box::new([&repr.out].into_iter()),
2495 FloatOperationIr::Random(repr) => Box::new([&repr.out].into_iter()),
2496 FloatOperationIr::Exp(repr) => Box::new([&repr.out].into_iter()),
2497 FloatOperationIr::Log(repr) => Box::new([&repr.out].into_iter()),
2498 FloatOperationIr::Log1p(repr) => Box::new([&repr.out].into_iter()),
2499 FloatOperationIr::Erf(repr) => Box::new([&repr.out].into_iter()),
2500 FloatOperationIr::Recip(repr) => Box::new([&repr.out].into_iter()),
2501 FloatOperationIr::PowfScalar(repr) => Box::new([&repr.out].into_iter()),
2502 FloatOperationIr::Sqrt(repr) => Box::new([&repr.out].into_iter()),
2503 FloatOperationIr::Cos(repr) => Box::new([&repr.out].into_iter()),
2504 FloatOperationIr::Sin(repr) => Box::new([&repr.out].into_iter()),
2505 FloatOperationIr::Tanh(repr) => Box::new([&repr.out].into_iter()),
2506 FloatOperationIr::Round(repr) => Box::new([&repr.out].into_iter()),
2507 FloatOperationIr::Floor(repr) => Box::new([&repr.out].into_iter()),
2508 FloatOperationIr::Ceil(repr) => Box::new([&repr.out].into_iter()),
2509 FloatOperationIr::Trunc(repr) => Box::new([&repr.out].into_iter()),
2510 FloatOperationIr::IntoInt(repr) => Box::new([&repr.out].into_iter()),
2511 FloatOperationIr::Quantize(repr) => Box::new([&repr.out].into_iter()),
2512 FloatOperationIr::Dequantize(repr) => Box::new([&repr.out].into_iter()),
2513 FloatOperationIr::IsNan(repr) => Box::new([&repr.out].into_iter()),
2514 FloatOperationIr::IsInf(repr) => Box::new([&repr.out].into_iter()),
2515 FloatOperationIr::GridSample2d(repr) => Box::new([&repr.out].into_iter()),
2516 FloatOperationIr::Tan(repr) => Box::new([&repr.out].into_iter()),
2517 FloatOperationIr::Cosh(repr) => Box::new([&repr.out].into_iter()),
2518 FloatOperationIr::Sinh(repr) => Box::new([&repr.out].into_iter()),
2519 FloatOperationIr::ArcCos(repr) => Box::new([&repr.out].into_iter()),
2520 FloatOperationIr::ArcCosh(repr) => Box::new([&repr.out].into_iter()),
2521 FloatOperationIr::ArcSin(repr) => Box::new([&repr.out].into_iter()),
2522 FloatOperationIr::ArcSinh(repr) => Box::new([&repr.out].into_iter()),
2523 FloatOperationIr::ArcTan(repr) => Box::new([&repr.out].into_iter()),
2524 FloatOperationIr::ArcTanh(repr) => Box::new([&repr.out].into_iter()),
2525 FloatOperationIr::ArcTan2(repr) => Box::new([&repr.out].into_iter()),
2526 FloatOperationIr::Powf(repr) => Box::new([&repr.out].into_iter()),
2527 }
2528 }
2529
2530 fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
2531 let mut output = Vec::new();
2532
2533 match self {
2534 FloatOperationIr::Matmul(repr) => {
2535 repr.lhs.mark_read_only(nodes, &mut output);
2536 repr.rhs.mark_read_only(nodes, &mut output);
2537 }
2538 FloatOperationIr::Cross(repr) => {
2539 repr.lhs.mark_read_only(nodes, &mut output);
2540 repr.rhs.mark_read_only(nodes, &mut output);
2541 }
2542 FloatOperationIr::Random(_) => {}
2543 FloatOperationIr::Exp(repr) => {
2544 repr.input.mark_read_only(nodes, &mut output);
2545 }
2546 FloatOperationIr::Log(repr) => {
2547 repr.input.mark_read_only(nodes, &mut output);
2548 }
2549 FloatOperationIr::Log1p(repr) => {
2550 repr.input.mark_read_only(nodes, &mut output);
2551 }
2552 FloatOperationIr::Erf(repr) => {
2553 repr.input.mark_read_only(nodes, &mut output);
2554 }
2555 FloatOperationIr::Recip(repr) => {
2556 repr.input.mark_read_only(nodes, &mut output);
2557 }
2558 FloatOperationIr::PowfScalar(repr) => {
2559 repr.lhs.mark_read_only(nodes, &mut output);
2560 }
2561 FloatOperationIr::Sqrt(repr) => {
2562 repr.input.mark_read_only(nodes, &mut output);
2563 }
2564 FloatOperationIr::Cos(repr) => {
2565 repr.input.mark_read_only(nodes, &mut output);
2566 }
2567 FloatOperationIr::Sin(repr) => {
2568 repr.input.mark_read_only(nodes, &mut output);
2569 }
2570 FloatOperationIr::Tanh(repr) => {
2571 repr.input.mark_read_only(nodes, &mut output);
2572 }
2573 FloatOperationIr::Round(repr) => {
2574 repr.input.mark_read_only(nodes, &mut output);
2575 }
2576 FloatOperationIr::Floor(repr) => {
2577 repr.input.mark_read_only(nodes, &mut output);
2578 }
2579 FloatOperationIr::Ceil(repr) => {
2580 repr.input.mark_read_only(nodes, &mut output);
2581 }
2582 FloatOperationIr::Trunc(repr) => {
2583 repr.input.mark_read_only(nodes, &mut output);
2584 }
2585 FloatOperationIr::Quantize(repr) => {
2586 repr.tensor.mark_read_only(nodes, &mut output);
2587 repr.qparams.scales.mark_read_only(nodes, &mut output);
2588 }
2589 FloatOperationIr::Dequantize(repr) => {
2590 repr.input.mark_read_only(nodes, &mut output);
2591 }
2592 FloatOperationIr::IntoInt(repr) => {
2593 repr.input.mark_read_only(nodes, &mut output);
2594 }
2595 FloatOperationIr::IsNan(repr) => {
2596 repr.input.mark_read_only(nodes, &mut output);
2597 }
2598 FloatOperationIr::IsInf(repr) => {
2599 repr.input.mark_read_only(nodes, &mut output);
2600 }
2601 FloatOperationIr::GridSample2d(repr) => {
2602 repr.tensor.mark_read_only(nodes, &mut output);
2603 repr.grid.mark_read_only(nodes, &mut output);
2604 }
2605 FloatOperationIr::Tan(repr) => repr.input.mark_read_only(nodes, &mut output),
2606 FloatOperationIr::Cosh(repr) => repr.input.mark_read_only(nodes, &mut output),
2607 FloatOperationIr::Sinh(repr) => repr.input.mark_read_only(nodes, &mut output),
2608 FloatOperationIr::ArcCos(repr) => repr.input.mark_read_only(nodes, &mut output),
2609 FloatOperationIr::ArcCosh(repr) => repr.input.mark_read_only(nodes, &mut output),
2610 FloatOperationIr::ArcSin(repr) => repr.input.mark_read_only(nodes, &mut output),
2611 FloatOperationIr::ArcSinh(repr) => repr.input.mark_read_only(nodes, &mut output),
2612 FloatOperationIr::ArcTan(repr) => repr.input.mark_read_only(nodes, &mut output),
2613 FloatOperationIr::ArcTanh(repr) => repr.input.mark_read_only(nodes, &mut output),
2614 FloatOperationIr::ArcTan2(repr) => {
2615 repr.lhs.mark_read_only(nodes, &mut output);
2616 repr.rhs.mark_read_only(nodes, &mut output);
2617 }
2618 FloatOperationIr::Powf(repr) => {
2619 repr.lhs.mark_read_only(nodes, &mut output);
2620 repr.rhs.mark_read_only(nodes, &mut output);
2621 }
2622 };
2623
2624 output
2625 }
2626}
2627
2628impl IntOperationIr {
2629 fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
2630 match self {
2631 IntOperationIr::Matmul(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
2632 IntOperationIr::IntoFloat(repr) => Box::new([&repr.input].into_iter()),
2633 IntOperationIr::BitwiseAnd(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
2634 IntOperationIr::BitwiseAndScalar(repr) => Box::new([&repr.lhs].into_iter()),
2635 IntOperationIr::BitwiseOr(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
2636 IntOperationIr::BitwiseOrScalar(repr) => Box::new([&repr.lhs].into_iter()),
2637 IntOperationIr::BitwiseXor(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
2638 IntOperationIr::BitwiseXorScalar(repr) => Box::new([&repr.lhs].into_iter()),
2639 IntOperationIr::BitwiseNot(repr) => Box::new([&repr.input].into_iter()),
2640 IntOperationIr::BitwiseLeftShift(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
2641 IntOperationIr::BitwiseLeftShiftScalar(repr) => Box::new([&repr.lhs].into_iter()),
2642 IntOperationIr::BitwiseRightShift(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
2643 IntOperationIr::BitwiseRightShiftScalar(repr) => Box::new([&repr.lhs].into_iter()),
2644 }
2645 }
2646
2647 fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
2648 match self {
2649 IntOperationIr::Matmul(repr) => Box::new([&repr.out].into_iter()),
2650 IntOperationIr::IntoFloat(repr) => Box::new([&repr.out].into_iter()),
2651 IntOperationIr::BitwiseAnd(repr) => Box::new([&repr.out].into_iter()),
2652 IntOperationIr::BitwiseAndScalar(repr) => Box::new([&repr.out].into_iter()),
2653 IntOperationIr::BitwiseOr(repr) => Box::new([&repr.out].into_iter()),
2654 IntOperationIr::BitwiseOrScalar(repr) => Box::new([&repr.out].into_iter()),
2655 IntOperationIr::BitwiseXor(repr) => Box::new([&repr.out].into_iter()),
2656 IntOperationIr::BitwiseXorScalar(repr) => Box::new([&repr.out].into_iter()),
2657 IntOperationIr::BitwiseNot(repr) => Box::new([&repr.out].into_iter()),
2658 IntOperationIr::BitwiseLeftShift(repr) => Box::new([&repr.out].into_iter()),
2659 IntOperationIr::BitwiseLeftShiftScalar(repr) => Box::new([&repr.out].into_iter()),
2660 IntOperationIr::BitwiseRightShift(repr) => Box::new([&repr.out].into_iter()),
2661 IntOperationIr::BitwiseRightShiftScalar(repr) => Box::new([&repr.out].into_iter()),
2662 }
2663 }
2664
2665 fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
2666 let mut output = Vec::new();
2667
2668 match self {
2669 IntOperationIr::Matmul(repr) => {
2670 repr.lhs.mark_read_only(nodes, &mut output);
2671 repr.rhs.mark_read_only(nodes, &mut output);
2672 }
2673 IntOperationIr::IntoFloat(repr) => {
2674 repr.input.mark_read_only(nodes, &mut output);
2675 }
2676 IntOperationIr::BitwiseAnd(repr) => {
2677 repr.lhs.mark_read_only(nodes, &mut output);
2678 repr.rhs.mark_read_only(nodes, &mut output);
2679 }
2680 IntOperationIr::BitwiseAndScalar(repr) => {
2681 repr.lhs.mark_read_only(nodes, &mut output);
2682 }
2683 IntOperationIr::BitwiseOr(repr) => {
2684 repr.lhs.mark_read_only(nodes, &mut output);
2685 repr.rhs.mark_read_only(nodes, &mut output);
2686 }
2687 IntOperationIr::BitwiseOrScalar(repr) => {
2688 repr.lhs.mark_read_only(nodes, &mut output);
2689 }
2690 IntOperationIr::BitwiseXor(repr) => {
2691 repr.lhs.mark_read_only(nodes, &mut output);
2692 repr.rhs.mark_read_only(nodes, &mut output);
2693 }
2694 IntOperationIr::BitwiseXorScalar(repr) => {
2695 repr.lhs.mark_read_only(nodes, &mut output);
2696 }
2697 IntOperationIr::BitwiseNot(repr) => {
2698 repr.input.mark_read_only(nodes, &mut output);
2699 }
2700 IntOperationIr::BitwiseLeftShift(repr) => {
2701 repr.lhs.mark_read_only(nodes, &mut output);
2702 repr.rhs.mark_read_only(nodes, &mut output);
2703 }
2704 IntOperationIr::BitwiseLeftShiftScalar(repr) => {
2705 repr.lhs.mark_read_only(nodes, &mut output);
2706 }
2707 IntOperationIr::BitwiseRightShift(repr) => {
2708 repr.lhs.mark_read_only(nodes, &mut output);
2709 repr.rhs.mark_read_only(nodes, &mut output);
2710 }
2711 IntOperationIr::BitwiseRightShiftScalar(repr) => {
2712 repr.lhs.mark_read_only(nodes, &mut output);
2713 }
2714 };
2715
2716 output
2717 }
2718}
2719
2720impl BoolOperationIr {
2721 fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
2722 match self {
2723 BoolOperationIr::IntoFloat(repr) => Box::new([&repr.input].into_iter()),
2724 BoolOperationIr::IntoInt(repr) => Box::new([&repr.input].into_iter()),
2725 BoolOperationIr::Not(repr) => Box::new([&repr.input].into_iter()),
2726 BoolOperationIr::And(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
2727 BoolOperationIr::Or(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
2728 }
2729 }
2730 fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
2731 match self {
2732 BoolOperationIr::IntoFloat(repr) => Box::new([&repr.out].into_iter()),
2733 BoolOperationIr::IntoInt(repr) => Box::new([&repr.out].into_iter()),
2734 BoolOperationIr::Not(repr) => Box::new([&repr.out].into_iter()),
2735 BoolOperationIr::And(repr) => Box::new([&repr.out].into_iter()),
2736 BoolOperationIr::Or(repr) => Box::new([&repr.out].into_iter()),
2737 }
2738 }
2739 fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
2740 let mut output = Vec::new();
2741
2742 match self {
2743 BoolOperationIr::IntoFloat(repr) => {
2744 repr.input.mark_read_only(nodes, &mut output);
2745 }
2746 BoolOperationIr::IntoInt(repr) => {
2747 repr.input.mark_read_only(nodes, &mut output);
2748 }
2749 BoolOperationIr::Not(repr) => {
2750 repr.input.mark_read_only(nodes, &mut output);
2751 }
2752 BoolOperationIr::And(repr) => {
2753 repr.lhs.mark_read_only(nodes, &mut output);
2754 repr.rhs.mark_read_only(nodes, &mut output);
2755 }
2756 BoolOperationIr::Or(repr) => {
2757 repr.lhs.mark_read_only(nodes, &mut output);
2758 repr.rhs.mark_read_only(nodes, &mut output);
2759 }
2760 };
2761
2762 output
2763 }
2764}
2765
2766impl ModuleOperationIr {
2767 fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
2768 match self {
2769 ModuleOperationIr::Embedding(repr) => {
2770 Box::new([&repr.weights, &repr.indices].into_iter())
2771 }
2772 ModuleOperationIr::EmbeddingBackward(repr) => {
2773 Box::new([&repr.weights, &repr.out_grad, &repr.indices].into_iter())
2774 }
2775 ModuleOperationIr::Linear(repr) => {
2776 if let Some(bias) = &repr.bias {
2777 Box::new([&repr.x, &repr.weight, bias].into_iter())
2778 } else {
2779 Box::new([&repr.x, &repr.weight].into_iter())
2780 }
2781 }
2782 ModuleOperationIr::LinearXBackward(repr) => {
2783 Box::new([&repr.weight, &repr.output_grad].into_iter())
2784 }
2785 ModuleOperationIr::LinearWeightBackward(repr) => {
2786 Box::new([&repr.x, &repr.output_grad].into_iter())
2787 }
2788 ModuleOperationIr::LinearBiasBackward(repr) => {
2789 Box::new([&repr.output_grad].into_iter())
2790 }
2791 ModuleOperationIr::Conv1d(repr) => {
2792 if let Some(bias) = &repr.bias {
2793 Box::new([&repr.x, &repr.weight, bias].into_iter())
2794 } else {
2795 Box::new([&repr.x, &repr.weight].into_iter())
2796 }
2797 }
2798 ModuleOperationIr::Conv1dXBackward(repr) => {
2799 Box::new([&repr.x, &repr.weight, &repr.output_grad].into_iter())
2800 }
2801 ModuleOperationIr::Conv1dWeightBackward(repr) => {
2802 Box::new([&repr.x, &repr.weight, &repr.output_grad].into_iter())
2803 }
2804 ModuleOperationIr::Conv1dBiasBackward(repr) => {
2805 Box::new([&repr.x, &repr.bias, &repr.output_grad].into_iter())
2806 }
2807 ModuleOperationIr::Conv2d(repr) => {
2808 if let Some(bias) = &repr.bias {
2809 Box::new([&repr.x, &repr.weight, bias].into_iter())
2810 } else {
2811 Box::new([&repr.x, &repr.weight].into_iter())
2812 }
2813 }
2814 ModuleOperationIr::Conv2dXBackward(repr) => {
2815 Box::new([&repr.x, &repr.weight, &repr.output_grad].into_iter())
2816 }
2817 ModuleOperationIr::Conv2dWeightBackward(repr) => {
2818 Box::new([&repr.x, &repr.weight, &repr.output_grad].into_iter())
2819 }
2820 ModuleOperationIr::Conv2dBiasBackward(repr) => {
2821 Box::new([&repr.x, &repr.bias, &repr.output_grad].into_iter())
2822 }
2823 ModuleOperationIr::Conv3d(repr) => {
2824 if let Some(bias) = &repr.bias {
2825 Box::new([&repr.x, &repr.weight, bias].into_iter())
2826 } else {
2827 Box::new([&repr.x, &repr.weight].into_iter())
2828 }
2829 }
2830 ModuleOperationIr::Conv3dXBackward(repr) => {
2831 Box::new([&repr.x, &repr.weight, &repr.output_grad].into_iter())
2832 }
2833 ModuleOperationIr::Conv3dWeightBackward(repr) => {
2834 Box::new([&repr.x, &repr.weight, &repr.output_grad].into_iter())
2835 }
2836 ModuleOperationIr::Conv3dBiasBackward(repr) => {
2837 Box::new([&repr.x, &repr.bias, &repr.output_grad].into_iter())
2838 }
2839 ModuleOperationIr::DeformableConv2d(repr) => match (&repr.mask, &repr.bias) {
2840 (Some(mask), Some(bias)) => {
2841 Box::new([&repr.x, &repr.offset, &repr.weight, mask, bias].into_iter())
2842 }
2843 (Some(mask), None) => {
2844 Box::new([&repr.x, &repr.offset, &repr.weight, mask].into_iter())
2845 }
2846 (None, Some(bias)) => {
2847 Box::new([&repr.x, &repr.offset, &repr.weight, bias].into_iter())
2848 }
2849 (None, None) => Box::new([&repr.x, &repr.offset, &repr.weight].into_iter()),
2850 },
2851 ModuleOperationIr::DeformableConv2dBackward(repr) => match (&repr.mask, &repr.bias) {
2852 (Some(mask), Some(bias)) => Box::new(
2853 [
2854 &repr.x,
2855 &repr.offset,
2856 &repr.weight,
2857 &repr.out_grad,
2858 mask,
2859 bias,
2860 ]
2861 .into_iter(),
2862 ),
2863 (Some(mask), None) => Box::new(
2864 [&repr.x, &repr.offset, &repr.weight, &repr.out_grad, mask].into_iter(),
2865 ),
2866 (None, Some(bias)) => Box::new(
2867 [&repr.x, &repr.offset, &repr.weight, &repr.out_grad, bias].into_iter(),
2868 ),
2869 (None, None) => {
2870 Box::new([&repr.x, &repr.offset, &repr.weight, &repr.out_grad].into_iter())
2871 }
2872 },
2873 ModuleOperationIr::ConvTranspose1d(repr) => {
2874 if let Some(bias) = &repr.bias {
2875 Box::new([&repr.x, &repr.weight, bias].into_iter())
2876 } else {
2877 Box::new([&repr.x, &repr.weight].into_iter())
2878 }
2879 }
2880 ModuleOperationIr::ConvTranspose2d(repr) => {
2881 if let Some(bias) = &repr.bias {
2882 Box::new([&repr.x, &repr.weight, bias].into_iter())
2883 } else {
2884 Box::new([&repr.x, &repr.weight].into_iter())
2885 }
2886 }
2887 ModuleOperationIr::ConvTranspose3d(repr) => {
2888 if let Some(bias) = &repr.bias {
2889 Box::new([&repr.x, &repr.weight, bias].into_iter())
2890 } else {
2891 Box::new([&repr.x, &repr.weight].into_iter())
2892 }
2893 }
2894 ModuleOperationIr::AvgPool1d(repr) => Box::new([&repr.x].into_iter()),
2895 ModuleOperationIr::AvgPool2d(repr) => Box::new([&repr.x].into_iter()),
2896 ModuleOperationIr::AvgPool1dBackward(repr) => {
2897 Box::new([&repr.x, &repr.grad].into_iter())
2898 }
2899 ModuleOperationIr::AvgPool2dBackward(repr) => {
2900 Box::new([&repr.x, &repr.grad].into_iter())
2901 }
2902 ModuleOperationIr::AdaptiveAvgPool1d(repr) => Box::new([&repr.x].into_iter()),
2903 ModuleOperationIr::AdaptiveAvgPool2d(repr) => Box::new([&repr.x].into_iter()),
2904 ModuleOperationIr::AdaptiveAvgPool1dBackward(repr) => {
2905 Box::new([&repr.x, &repr.grad].into_iter())
2906 }
2907 ModuleOperationIr::AdaptiveAvgPool2dBackward(repr) => {
2908 Box::new([&repr.x, &repr.grad].into_iter())
2909 }
2910 ModuleOperationIr::MaxPool1d(repr) => Box::new([&repr.x].into_iter()),
2911 ModuleOperationIr::MaxPool1dWithIndices(repr) => Box::new([&repr.x].into_iter()),
2912 ModuleOperationIr::MaxPool1dWithIndicesBackward(repr) => {
2913 Box::new([&repr.x, &repr.indices, &repr.grad].into_iter())
2914 }
2915 ModuleOperationIr::MaxPool2d(repr) => Box::new([&repr.x].into_iter()),
2916 ModuleOperationIr::MaxPool2dWithIndices(repr) => Box::new([&repr.x].into_iter()),
2917 ModuleOperationIr::MaxPool2dWithIndicesBackward(repr) => {
2918 Box::new([&repr.x, &repr.indices, &repr.grad].into_iter())
2919 }
2920 ModuleOperationIr::Interpolate(repr) => Box::new([&repr.x].into_iter()),
2921 ModuleOperationIr::InterpolateBackward(repr) => {
2922 Box::new([&repr.x, &repr.grad].into_iter())
2923 }
2924 ModuleOperationIr::Rfft(repr) => Box::new([&repr.signal].into_iter()),
2925 ModuleOperationIr::IRfft(repr) => {
2926 Box::new([&repr.input_re, &repr.input_im].into_iter())
2927 }
2928 ModuleOperationIr::Attention(repr) => {
2929 if let Some(mask) = &repr.mask {
2930 if let Some(attn_bias) = &repr.attn_bias {
2931 Box::new([&repr.query, &repr.key, &repr.value, mask, attn_bias].into_iter())
2932 } else {
2933 Box::new([&repr.query, &repr.key, &repr.value, mask].into_iter())
2934 }
2935 } else if let Some(attn_bias) = &repr.attn_bias {
2936 Box::new([&repr.query, &repr.key, &repr.value, attn_bias].into_iter())
2937 } else {
2938 Box::new([&repr.query, &repr.key, &repr.value].into_iter())
2939 }
2940 }
2941 ModuleOperationIr::CtcLoss(repr) => Box::new(
2942 [
2943 &repr.log_probs,
2944 &repr.targets,
2945 &repr.input_lengths,
2946 &repr.target_lengths,
2947 ]
2948 .into_iter(),
2949 ),
2950 ModuleOperationIr::CtcLossBackward(repr) => Box::new(
2951 [
2952 &repr.log_probs,
2953 &repr.targets,
2954 &repr.input_lengths,
2955 &repr.target_lengths,
2956 &repr.grad_loss,
2957 ]
2958 .into_iter(),
2959 ),
2960 }
2961 }
2962 fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
2963 match self {
2964 ModuleOperationIr::Embedding(repr) => Box::new([&repr.out].into_iter()),
2965 ModuleOperationIr::EmbeddingBackward(repr) => Box::new([&repr.out].into_iter()),
2966 ModuleOperationIr::Linear(repr) => Box::new([&repr.out].into_iter()),
2967 ModuleOperationIr::LinearXBackward(repr) => Box::new([&repr.out].into_iter()),
2968 ModuleOperationIr::LinearWeightBackward(repr) => Box::new([&repr.out].into_iter()),
2969 ModuleOperationIr::LinearBiasBackward(repr) => Box::new([&repr.out].into_iter()),
2970 ModuleOperationIr::Conv1d(repr) => Box::new([&repr.out].into_iter()),
2971 ModuleOperationIr::Conv1dXBackward(repr) => Box::new([&repr.out].into_iter()),
2972 ModuleOperationIr::Conv1dWeightBackward(repr) => Box::new([&repr.out].into_iter()),
2973 ModuleOperationIr::Conv1dBiasBackward(repr) => Box::new([&repr.out].into_iter()),
2974 ModuleOperationIr::Conv2d(repr) => Box::new([&repr.out].into_iter()),
2975 ModuleOperationIr::Conv2dXBackward(repr) => Box::new([&repr.out].into_iter()),
2976 ModuleOperationIr::Conv2dWeightBackward(repr) => Box::new([&repr.out].into_iter()),
2977 ModuleOperationIr::Conv2dBiasBackward(repr) => Box::new([&repr.out].into_iter()),
2978 ModuleOperationIr::Conv3d(repr) => Box::new([&repr.out].into_iter()),
2979 ModuleOperationIr::Conv3dXBackward(repr) => Box::new([&repr.out].into_iter()),
2980 ModuleOperationIr::Conv3dWeightBackward(repr) => Box::new([&repr.out].into_iter()),
2981 ModuleOperationIr::Conv3dBiasBackward(repr) => Box::new([&repr.out].into_iter()),
2982 ModuleOperationIr::DeformableConv2d(repr) => Box::new([&repr.out].into_iter()),
2983 ModuleOperationIr::DeformableConv2dBackward(repr) => {
2984 match (&repr.mask_grad, &repr.bias_grad) {
2985 (Some(mask_grad), Some(bias_grad)) => Box::new(
2986 [
2987 &repr.input_grad,
2988 &repr.offset_grad,
2989 &repr.weight_grad,
2990 mask_grad,
2991 bias_grad,
2992 ]
2993 .into_iter(),
2994 ),
2995 (Some(mask_grad), None) => Box::new(
2996 [
2997 &repr.input_grad,
2998 &repr.offset_grad,
2999 &repr.weight_grad,
3000 mask_grad,
3001 ]
3002 .into_iter(),
3003 ),
3004 (None, Some(bias_grad)) => Box::new(
3005 [
3006 &repr.input_grad,
3007 &repr.offset_grad,
3008 &repr.weight_grad,
3009 bias_grad,
3010 ]
3011 .into_iter(),
3012 ),
3013 (None, None) => Box::new(
3014 [&repr.input_grad, &repr.offset_grad, &repr.weight_grad].into_iter(),
3015 ),
3016 }
3017 }
3018 ModuleOperationIr::ConvTranspose1d(repr) => Box::new([&repr.out].into_iter()),
3019 ModuleOperationIr::ConvTranspose2d(repr) => Box::new([&repr.out].into_iter()),
3020 ModuleOperationIr::ConvTranspose3d(repr) => Box::new([&repr.out].into_iter()),
3021 ModuleOperationIr::AvgPool1d(repr) => Box::new([&repr.out].into_iter()),
3022 ModuleOperationIr::AvgPool2d(repr) => Box::new([&repr.out].into_iter()),
3023 ModuleOperationIr::AvgPool1dBackward(repr) => Box::new([&repr.out].into_iter()),
3024 ModuleOperationIr::AvgPool2dBackward(repr) => Box::new([&repr.out].into_iter()),
3025 ModuleOperationIr::AdaptiveAvgPool1d(repr) => Box::new([&repr.out].into_iter()),
3026 ModuleOperationIr::AdaptiveAvgPool2d(repr) => Box::new([&repr.out].into_iter()),
3027 ModuleOperationIr::AdaptiveAvgPool1dBackward(repr) => Box::new([&repr.out].into_iter()),
3028 ModuleOperationIr::AdaptiveAvgPool2dBackward(repr) => Box::new([&repr.out].into_iter()),
3029 ModuleOperationIr::MaxPool1d(repr) => Box::new([&repr.out].into_iter()),
3030 ModuleOperationIr::MaxPool1dWithIndices(repr) => {
3031 Box::new([&repr.out, &repr.out_indices].into_iter())
3032 }
3033 ModuleOperationIr::MaxPool1dWithIndicesBackward(repr) => {
3034 Box::new([&repr.out].into_iter())
3035 }
3036 ModuleOperationIr::MaxPool2d(repr) => Box::new([&repr.out].into_iter()),
3037 ModuleOperationIr::MaxPool2dWithIndices(repr) => {
3038 Box::new([&repr.out, &repr.out_indices].into_iter())
3039 }
3040 ModuleOperationIr::MaxPool2dWithIndicesBackward(repr) => {
3041 Box::new([&repr.out].into_iter())
3042 }
3043 ModuleOperationIr::Interpolate(repr) => Box::new([&repr.out].into_iter()),
3044 ModuleOperationIr::InterpolateBackward(repr) => Box::new([&repr.out].into_iter()),
3045 ModuleOperationIr::Rfft(repr) => Box::new([&repr.out_re, &repr.out_im].into_iter()),
3046 ModuleOperationIr::IRfft(repr) => Box::new([&repr.out_signal].into_iter()),
3047 ModuleOperationIr::Attention(repr) => Box::new([&repr.out].into_iter()),
3048 ModuleOperationIr::CtcLoss(repr) => Box::new([&repr.out].into_iter()),
3049 ModuleOperationIr::CtcLossBackward(repr) => Box::new([&repr.out].into_iter()),
3050 }
3051 }
3052
3053 fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
3054 let mut output = Vec::new();
3055
3056 match self {
3057 ModuleOperationIr::Embedding(repr) => {
3058 repr.weights.mark_read_only(nodes, &mut output);
3059 repr.indices.mark_read_only(nodes, &mut output);
3060 }
3061 ModuleOperationIr::EmbeddingBackward(repr) => {
3062 repr.weights.mark_read_only(nodes, &mut output);
3063 repr.out_grad.mark_read_only(nodes, &mut output);
3064 repr.indices.mark_read_only(nodes, &mut output);
3065 }
3066 ModuleOperationIr::Linear(repr) => {
3067 repr.x.mark_read_only(nodes, &mut output);
3068 repr.weight.mark_read_only(nodes, &mut output);
3069
3070 if let Some(bias) = &mut repr.bias {
3071 bias.mark_read_only(nodes, &mut output);
3072 }
3073 }
3074 ModuleOperationIr::LinearXBackward(repr) => {
3075 repr.weight.mark_read_only(nodes, &mut output);
3076 repr.output_grad.mark_read_only(nodes, &mut output);
3077 }
3078 ModuleOperationIr::LinearWeightBackward(repr) => {
3079 repr.x.mark_read_only(nodes, &mut output);
3080 repr.output_grad.mark_read_only(nodes, &mut output);
3081 }
3082 ModuleOperationIr::LinearBiasBackward(repr) => {
3083 repr.output_grad.mark_read_only(nodes, &mut output);
3084 }
3085 ModuleOperationIr::Conv1d(repr) => {
3086 repr.x.mark_read_only(nodes, &mut output);
3087 repr.weight.mark_read_only(nodes, &mut output);
3088
3089 if let Some(bias) = &mut repr.bias {
3090 bias.mark_read_only(nodes, &mut output);
3091 }
3092 }
3093 ModuleOperationIr::Conv1dXBackward(repr) => {
3094 repr.x.mark_read_only(nodes, &mut output);
3095 repr.weight.mark_read_only(nodes, &mut output);
3096 repr.output_grad.mark_read_only(nodes, &mut output);
3097 }
3098 ModuleOperationIr::Conv1dWeightBackward(repr) => {
3099 repr.x.mark_read_only(nodes, &mut output);
3100 repr.weight.mark_read_only(nodes, &mut output);
3101 repr.output_grad.mark_read_only(nodes, &mut output);
3102 }
3103 ModuleOperationIr::Conv1dBiasBackward(repr) => {
3104 repr.x.mark_read_only(nodes, &mut output);
3105 repr.bias.mark_read_only(nodes, &mut output);
3106 repr.output_grad.mark_read_only(nodes, &mut output);
3107 }
3108 ModuleOperationIr::Conv2d(repr) => {
3109 repr.x.mark_read_only(nodes, &mut output);
3110 repr.weight.mark_read_only(nodes, &mut output);
3111
3112 if let Some(bias) = &mut repr.bias {
3113 bias.mark_read_only(nodes, &mut output);
3114 }
3115 }
3116 ModuleOperationIr::Conv2dXBackward(repr) => {
3117 repr.x.mark_read_only(nodes, &mut output);
3118 repr.weight.mark_read_only(nodes, &mut output);
3119 repr.output_grad.mark_read_only(nodes, &mut output);
3120 }
3121 ModuleOperationIr::Conv2dWeightBackward(repr) => {
3122 repr.x.mark_read_only(nodes, &mut output);
3123 repr.weight.mark_read_only(nodes, &mut output);
3124 repr.output_grad.mark_read_only(nodes, &mut output);
3125 }
3126 ModuleOperationIr::Conv2dBiasBackward(repr) => {
3127 repr.x.mark_read_only(nodes, &mut output);
3128 repr.bias.mark_read_only(nodes, &mut output);
3129 repr.output_grad.mark_read_only(nodes, &mut output);
3130 }
3131 ModuleOperationIr::Conv3d(repr) => {
3132 repr.x.mark_read_only(nodes, &mut output);
3133 repr.weight.mark_read_only(nodes, &mut output);
3134
3135 if let Some(bias) = &mut repr.bias {
3136 bias.mark_read_only(nodes, &mut output);
3137 }
3138 }
3139 ModuleOperationIr::Conv3dXBackward(repr) => {
3140 repr.x.mark_read_only(nodes, &mut output);
3141 repr.weight.mark_read_only(nodes, &mut output);
3142 repr.output_grad.mark_read_only(nodes, &mut output);
3143 }
3144 ModuleOperationIr::Conv3dWeightBackward(repr) => {
3145 repr.x.mark_read_only(nodes, &mut output);
3146 repr.weight.mark_read_only(nodes, &mut output);
3147 repr.output_grad.mark_read_only(nodes, &mut output);
3148 }
3149 ModuleOperationIr::Conv3dBiasBackward(repr) => {
3150 repr.x.mark_read_only(nodes, &mut output);
3151 repr.bias.mark_read_only(nodes, &mut output);
3152 repr.output_grad.mark_read_only(nodes, &mut output);
3153 }
3154 ModuleOperationIr::DeformableConv2d(repr) => {
3155 repr.x.mark_read_only(nodes, &mut output);
3156 repr.weight.mark_read_only(nodes, &mut output);
3157 repr.offset.mark_read_only(nodes, &mut output);
3158
3159 match (&mut repr.mask, &mut repr.bias) {
3160 (Some(mask), Some(bias)) => {
3161 mask.mark_read_only(nodes, &mut output);
3162 bias.mark_read_only(nodes, &mut output);
3163 }
3164 (Some(mask), None) => {
3165 mask.mark_read_only(nodes, &mut output);
3166 }
3167 (None, Some(bias)) => {
3168 bias.mark_read_only(nodes, &mut output);
3169 }
3170 (None, None) => {}
3171 };
3172 }
3173 ModuleOperationIr::DeformableConv2dBackward(repr) => {
3174 repr.x.mark_read_only(nodes, &mut output);
3175 repr.weight.mark_read_only(nodes, &mut output);
3176 repr.offset.mark_read_only(nodes, &mut output);
3177 repr.out_grad.mark_read_only(nodes, &mut output);
3178
3179 if let Some(mask) = repr.mask.as_mut() {
3180 mask.mark_read_only(nodes, &mut output);
3181 }
3182 if let Some(bias) = repr.bias.as_mut() {
3183 bias.mark_read_only(nodes, &mut output);
3184 }
3185 }
3186 ModuleOperationIr::ConvTranspose1d(repr) => {
3187 repr.x.mark_read_only(nodes, &mut output);
3188 repr.weight.mark_read_only(nodes, &mut output);
3189
3190 if let Some(bias) = &mut repr.bias {
3191 bias.mark_read_only(nodes, &mut output);
3192 }
3193 }
3194 ModuleOperationIr::ConvTranspose2d(repr) => {
3195 repr.x.mark_read_only(nodes, &mut output);
3196 repr.weight.mark_read_only(nodes, &mut output);
3197
3198 if let Some(bias) = &mut repr.bias {
3199 bias.mark_read_only(nodes, &mut output);
3200 }
3201 }
3202 ModuleOperationIr::ConvTranspose3d(repr) => {
3203 repr.x.mark_read_only(nodes, &mut output);
3204 repr.weight.mark_read_only(nodes, &mut output);
3205
3206 if let Some(bias) = &mut repr.bias {
3207 bias.mark_read_only(nodes, &mut output);
3208 }
3209 }
3210 ModuleOperationIr::AvgPool1d(repr) => {
3211 repr.x.mark_read_only(nodes, &mut output);
3212 }
3213 ModuleOperationIr::AvgPool2d(repr) => {
3214 repr.x.mark_read_only(nodes, &mut output);
3215 }
3216 ModuleOperationIr::AvgPool1dBackward(repr) => {
3217 repr.x.mark_read_only(nodes, &mut output);
3218 repr.grad.mark_read_only(nodes, &mut output);
3219 }
3220 ModuleOperationIr::AvgPool2dBackward(repr) => {
3221 repr.x.mark_read_only(nodes, &mut output);
3222 repr.grad.mark_read_only(nodes, &mut output);
3223 }
3224 ModuleOperationIr::AdaptiveAvgPool1d(repr) => {
3225 repr.x.mark_read_only(nodes, &mut output);
3226 }
3227 ModuleOperationIr::AdaptiveAvgPool2d(repr) => {
3228 repr.x.mark_read_only(nodes, &mut output);
3229 }
3230 ModuleOperationIr::AdaptiveAvgPool1dBackward(repr) => {
3231 repr.x.mark_read_only(nodes, &mut output);
3232 repr.grad.mark_read_only(nodes, &mut output);
3233 }
3234 ModuleOperationIr::AdaptiveAvgPool2dBackward(repr) => {
3235 repr.x.mark_read_only(nodes, &mut output);
3236 repr.grad.mark_read_only(nodes, &mut output);
3237 }
3238 ModuleOperationIr::MaxPool1d(repr) => {
3239 repr.x.mark_read_only(nodes, &mut output);
3240 }
3241 ModuleOperationIr::MaxPool1dWithIndices(repr) => {
3242 repr.x.mark_read_only(nodes, &mut output);
3243 }
3244 ModuleOperationIr::MaxPool1dWithIndicesBackward(repr) => {
3245 repr.x.mark_read_only(nodes, &mut output);
3246 repr.grad.mark_read_only(nodes, &mut output);
3247 }
3248 ModuleOperationIr::MaxPool2d(repr) => {
3249 repr.x.mark_read_only(nodes, &mut output);
3250 }
3251 ModuleOperationIr::MaxPool2dWithIndices(repr) => {
3252 repr.x.mark_read_only(nodes, &mut output);
3253 }
3254 ModuleOperationIr::MaxPool2dWithIndicesBackward(repr) => {
3255 repr.x.mark_read_only(nodes, &mut output);
3256 repr.grad.mark_read_only(nodes, &mut output);
3257 }
3258 ModuleOperationIr::Interpolate(repr) => {
3259 repr.x.mark_read_only(nodes, &mut output);
3260 }
3261 ModuleOperationIr::InterpolateBackward(repr) => {
3262 repr.x.mark_read_only(nodes, &mut output);
3263 repr.grad.mark_read_only(nodes, &mut output);
3264 }
3265 ModuleOperationIr::Rfft(repr) => {
3266 repr.signal.mark_read_only(nodes, &mut output);
3267 }
3268 ModuleOperationIr::IRfft(repr) => {
3269 repr.input_re.mark_read_only(nodes, &mut output);
3270 repr.input_im.mark_read_only(nodes, &mut output);
3271 }
3272 ModuleOperationIr::Attention(repr) => {
3273 repr.query.mark_read_only(nodes, &mut output);
3274 repr.key.mark_read_only(nodes, &mut output);
3275 repr.value.mark_read_only(nodes, &mut output);
3276 if let Some(mask) = &mut repr.mask {
3277 mask.mark_read_only(nodes, &mut output);
3278 }
3279 if let Some(attn_bias) = &mut repr.attn_bias {
3280 attn_bias.mark_read_only(nodes, &mut output);
3281 }
3282 }
3283 ModuleOperationIr::CtcLoss(repr) => {
3284 repr.log_probs.mark_read_only(nodes, &mut output);
3285 repr.targets.mark_read_only(nodes, &mut output);
3286 repr.input_lengths.mark_read_only(nodes, &mut output);
3287 repr.target_lengths.mark_read_only(nodes, &mut output);
3288 }
3289 ModuleOperationIr::CtcLossBackward(repr) => {
3290 repr.log_probs.mark_read_only(nodes, &mut output);
3291 repr.targets.mark_read_only(nodes, &mut output);
3292 repr.input_lengths.mark_read_only(nodes, &mut output);
3293 repr.target_lengths.mark_read_only(nodes, &mut output);
3294 repr.grad_loss.mark_read_only(nodes, &mut output);
3295 }
3296 };
3297
3298 output
3299 }
3300}
3301
#[cfg(feature = "distributed")]
impl DistributedOperationIr {
    /// Tensors read by the distributed operation.
    fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
        // `AllReduce` is currently the only variant, so the pattern is irrefutable.
        let DistributedOperationIr::AllReduce(repr) = self;
        Box::new(core::iter::once(&repr.tensor))
    }

    /// Tensors written by the distributed operation.
    fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
        let DistributedOperationIr::AllReduce(repr) = self;
        Box::new(core::iter::once(&repr.out))
    }

    /// Flips matching read-write input tensors to read-only and returns a copy of
    /// each tensor's state as it was before the flip.
    fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
        let mut changed = Vec::new();

        let DistributedOperationIr::AllReduce(repr) = self;
        repr.tensor.mark_read_only(nodes, &mut changed);

        changed
    }
}
3328
3329impl InitOperationIr {
3330 fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
3331 Box::new([].into_iter())
3332 }
3333 fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
3334 Box::new([&self.out].into_iter())
3335 }
3336}
3337
3338impl TensorIr {
3339 fn mark_read_only(&mut self, nodes: &[TensorId], output: &mut Vec<TensorIr>) {
3340 if self.status == TensorStatus::ReadWrite && nodes.contains(&self.id) {
3341 output.push(self.clone());
3342 self.status = TensorStatus::ReadOnly;
3343 }
3344 }
3345}
3346
3347impl core::hash::Hash for RandomOpIr {
3348 fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
3349 self.out.hash(state);
3350
3351 match self.distribution {
3352 Distribution::Default => 1u8.hash(state),
3353 Distribution::Bernoulli(_) => 2u8.hash(state),
3354 Distribution::Uniform(_, _) => 3u8.hash(state),
3355 Distribution::Normal(_, _) => 4u8.hash(state),
3356 }
3357 }
3358}
3359
/// Convenient access to the outputs produced by an operation.
pub trait OperationOutput<O> {
    /// Returns the single output of the operation.
    ///
    /// # Panics
    ///
    /// Panics if the operation did not produce exactly one output.
    fn output(self) -> O;

    /// Returns the `N` outputs of the operation as a fixed-size array.
    ///
    /// # Panics
    ///
    /// Panics if the number of outputs is not exactly `N`.
    fn outputs<const N: usize>(self) -> [O; N];
}

// The `O: Debug` bound previously required by `unwrap()` is no longer needed
// since the panic message is built from the lengths, not the elements.
impl<O> OperationOutput<O> for Vec<O> {
    fn output(self) -> O {
        let [tensor] = self.outputs();
        tensor
    }

    fn outputs<const N: usize>(self) -> [O; N] {
        // Capture the length first: `try_into` consumes `self` on success and
        // returns it back only in the error case.
        let actual = self.len();
        self.try_into().unwrap_or_else(|_| {
            // A targeted message beats `unwrap()`, which would debug-print the
            // whole vector without stating the expected arity.
            panic!("Expected {} outputs, but the operation produced {}", N, actual)
        })
    }
}