1use core::hash::Hash;
2use core::ops::Range;
3use serde::{Deserialize, Serialize};
4
5use alloc::borrow::ToOwned;
6use alloc::boxed::Box;
7use alloc::{string::String, vec, vec::Vec};
8
9use crate::{
10 ops::{
11 ConvOptions, ConvTransposeOptions, DeformConvOptions, InterpolateMode, InterpolateOptions,
12 },
13 quantization::QuantizationScheme,
14 repr::tensor::TensorDescription,
15 DType, Distribution, Element,
16};
17
/// Description of a custom operation: an identifier plus its input and output tensors.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct CustomOpDescription {
    /// Identifier of the custom operation.
    pub id: String,
    /// Descriptions of the input tensors.
    pub inputs: Vec<TensorDescription>,
    /// Descriptions of the output tensors.
    pub outputs: Vec<TensorDescription>,
}
28
29impl CustomOpDescription {
30 pub fn new(
32 id: &'static str,
33 inputs: &[TensorDescription],
34 outputs: &[TensorDescription],
35 ) -> Self {
36 Self {
37 id: id.to_owned(),
38 inputs: inputs.to_vec(),
39 outputs: outputs.to_vec(),
40 }
41 }
42
43 pub fn consume<const N_IN: usize, const N_OUT: usize>(
45 self,
46 ) -> ([TensorDescription; N_IN], [TensorDescription; N_OUT]) {
47 (
48 self.inputs.try_into().expect(
49 "Wrong number of inputs expected (expected {D}, is {}), check your implementation",
50 ),
51 self.outputs.try_into().expect(
52 "Wrong number of outputs expected (expected {D}, is {}), check your implementation",
53 ),
54 )
55 }
56
57 fn nodes(&self) -> Vec<&TensorDescription> {
58 self.inputs.iter().chain(self.outputs.iter()).collect()
59 }
60}
61
/// Describes a single tensor operation; each variant wraps the description of one
/// family of operations.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum OperationDescription {
    /// Basic operation (float variant).
    BaseFloat(BaseOperationDescription),
    /// Basic operation (int variant).
    BaseInt(BaseOperationDescription),
    /// Basic operation (bool variant).
    BaseBool(BaseOperationDescription),
    /// Numeric operation with `f32` scalar operands, paired with a [`DType`].
    NumericFloat(DType, NumericOperationDescription<f32>),
    /// Numeric operation with `i32` scalar operands, paired with a [`DType`].
    NumericInt(DType, NumericOperationDescription<i32>),
    /// Bool-specific operation.
    Bool(BoolOperationDescription),
    /// Int-specific operation.
    Int(IntOperationDescription),
    /// Float-specific operation, paired with a [`DType`].
    Float(DType, FloatOperationDescription),
    /// Module operation.
    Module(ModuleOperationDescription),
    /// Custom operation.
    Custom(CustomOpDescription),
}
86
/// Description of a float-specific operation.
///
/// Variant meanings below are taken from the variant names.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum FloatOperationDescription {
    /// Exponential.
    Exp(UnaryOperationDescription),
    /// Logarithm.
    Log(UnaryOperationDescription),
    /// `log1p`.
    Log1p(UnaryOperationDescription),
    /// Error function.
    Erf(UnaryOperationDescription),
    /// Power with an `f32` scalar operand.
    PowfScalar(ScalarOperationDescription<f32>),
    /// Square root.
    Sqrt(UnaryOperationDescription),
    /// Cosine.
    Cos(UnaryOperationDescription),
    /// Sine.
    Sin(UnaryOperationDescription),
    /// Hyperbolic tangent.
    Tanh(UnaryOperationDescription),
    /// Rounding.
    Round(UnaryOperationDescription),
    /// Floor.
    Floor(UnaryOperationDescription),
    /// Ceil.
    Ceil(UnaryOperationDescription),
    /// Conversion into an int tensor.
    IntoInt(UnaryOperationDescription),
    /// Matrix multiplication of two tensors.
    Matmul(BinaryOperationDescription),
    /// Random tensor creation.
    Random(RandomOperationDescription),
    /// Reciprocal.
    Recip(UnaryOperationDescription),
    /// Quantization.
    Quantize(QuantizeOperationDescription),
    /// Dequantization.
    Dequantize(DequantizeOperationDescription),
}
127
/// Description of a module operation (embedding, convolution, pooling, interpolation)
/// and the backward variants listed here.
///
/// The two deformable-convolution payloads are boxed.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum ModuleOperationDescription {
    /// Embedding.
    Embedding(EmbeddingDescription),
    /// Embedding backward.
    EmbeddingBackward(EmbeddingBackwardDescription),
    /// 1D convolution.
    Conv1d(Conv1dDescription),
    /// 2D convolution.
    Conv2d(Conv2dDescription),
    /// 3D convolution.
    Conv3d(Conv3dDescription),
    /// 2D deformable convolution.
    DeformableConv2d(Box<DeformConv2dDescription>),
    /// 2D deformable convolution backward.
    DeformableConv2dBackward(Box<DeformConv2dBackwardDescription>),
    /// 1D transposed convolution.
    ConvTranspose1d(ConvTranspose1dDescription),
    /// 2D transposed convolution.
    ConvTranspose2d(ConvTranspose2dDescription),
    /// 3D transposed convolution.
    ConvTranspose3d(ConvTranspose3dDescription),
    /// 1D average pooling.
    AvgPool1d(AvgPool1dDescription),
    /// 2D average pooling.
    AvgPool2d(AvgPool2dDescription),
    /// 1D average pooling backward.
    AvgPool1dBackward(AvgPool1dBackwardDescription),
    /// 2D average pooling backward.
    AvgPool2dBackward(AvgPool2dBackwardDescription),
    /// 1D adaptive average pooling.
    AdaptiveAvgPool1d(AdaptiveAvgPool1dDescription),
    /// 2D adaptive average pooling.
    AdaptiveAvgPool2d(AdaptiveAvgPool2dDescription),
    /// 1D adaptive average pooling backward.
    AdaptiveAvgPool1dBackward(AdaptiveAvgPool1dBackwardDescription),
    /// 2D adaptive average pooling backward.
    AdaptiveAvgPool2dBackward(AdaptiveAvgPool2dBackwardDescription),
    /// 1D max pooling.
    MaxPool1d(MaxPool1dDescription),
    /// 1D max pooling that also produces indices.
    MaxPool1dWithIndices(MaxPool1dWithIndicesDescription),
    /// Backward of 1D max pooling with indices.
    MaxPool1dWithIndicesBackward(MaxPool1dWithIndicesBackwardDescription),
    /// 2D max pooling.
    MaxPool2d(MaxPool2dDescription),
    /// 2D max pooling that also produces indices.
    MaxPool2dWithIndices(MaxPool2dWithIndicesDescription),
    /// Backward of 2D max pooling with indices.
    MaxPool2dWithIndicesBackward(MaxPool2dWithIndicesBackwardDescription),
    /// Interpolation.
    Interpolate(InterpolateDescription),
    /// Interpolation backward.
    InterpolateBackward(InterpolateBackwardDescription),
}
196
/// Description of a basic operation; [`OperationDescription`] wraps this type in its
/// `BaseFloat`, `BaseInt` and `BaseBool` variants.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum BaseOperationDescription {
    /// Move to a device; carries a single tensor description.
    ToDevice(TensorDescription),
    /// Reshape.
    Reshape(ReshapeDescription),

    /// Swap of two dimensions.
    SwapDims(SwapDimsDescription),

    /// Permutation of axes.
    Permute(PermuteOperationDescription),

    /// Flip along axes.
    Flip(FlipOperationDescription),

    /// Expansion to a shape.
    Expand(ExpandOperationDescription),

    /// Slice.
    Slice(SliceOperationDescription),
    /// Slice assignment.
    SliceAssign(SliceAssignOperationDescription),
    /// Equality of two tensors.
    Equal(BinaryOperationDescription),
    /// Repetition along a dimension.
    RepeatDim(RepeatDimOperationDescription),
    /// Concatenation.
    Cat(CatOperationDescription),
    /// Cast.
    Cast(UnaryOperationDescription),

    /// Empty tensor; carries a single tensor description.
    Empty(TensorDescription),
}
280
/// Description of a numeric operation, generic over the scalar element type `E`.
///
/// `Hash` is not derived for this type. For the variants carrying
/// `ScalarOperationDescription<usize>` (`*Dim`, `ArgMax`, `ArgMin`) the `usize` operand is
/// presumably the dimension index — confirm against the call sites.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum NumericOperationDescription<E> {
    /// Addition of two tensors.
    Add(BinaryOperationDescription),
    /// Addition of a tensor and a scalar.
    AddScalar(ScalarOperationDescription<E>),
    /// Subtraction of two tensors.
    Sub(BinaryOperationDescription),
    /// Subtraction of a tensor and a scalar.
    SubScalar(ScalarOperationDescription<E>),
    /// Division of two tensors.
    Div(BinaryOperationDescription),
    /// Division of a tensor and a scalar.
    DivScalar(ScalarOperationDescription<E>),
    /// Remainder of two tensors.
    Rem(BinaryOperationDescription),
    /// Remainder of a tensor and a scalar.
    RemScalar(ScalarOperationDescription<E>),
    /// Multiplication of two tensors.
    Mul(BinaryOperationDescription),
    /// Multiplication of a tensor and a scalar.
    MulScalar(ScalarOperationDescription<E>),
    /// Absolute value.
    Abs(UnaryOperationDescription),
    /// Tensor of ones; carries a single tensor description.
    Ones(TensorDescription),
    /// Tensor of zeros; carries a single tensor description.
    Zeros(TensorDescription),
    /// Tensor filled with a value; carries the tensor description and the value.
    Full((TensorDescription, E)),
    /// Gather.
    Gather(GatherOperationDescription),
    /// Scatter.
    Scatter(ScatterOperationDescription),
    /// Select.
    Select(SelectOperationDescription),
    /// Select assignment.
    SelectAssign(SelectAssignOperationDescription),
    /// Mask-where.
    MaskWhere(MaskWhereOperationDescription),
    /// Mask-fill with a scalar value.
    MaskFill(MaskFillOperationDescription<E>),
    /// Mean along a dimension.
    MeanDim(ScalarOperationDescription<usize>),
    /// Mean.
    Mean(UnaryOperationDescription),
    /// Sum.
    Sum(UnaryOperationDescription),
    /// Sum along a dimension.
    SumDim(ScalarOperationDescription<usize>),

    /// Product.
    Prod(UnaryOperationDescription),

    /// Product along a dimension.
    ProdDim(ScalarOperationDescription<usize>),

    /// Equality between a tensor and a scalar.
    EqualElem(ScalarOperationDescription<E>),
    /// Greater-than between two tensors.
    Greater(BinaryOperationDescription),
    /// Greater-than between a tensor and a scalar.
    GreaterElem(ScalarOperationDescription<E>),
    /// Greater-or-equal between two tensors.
    GreaterEqual(BinaryOperationDescription),
    /// Greater-or-equal between a tensor and a scalar.
    GreaterEqualElem(ScalarOperationDescription<E>),
    /// Lower-than between two tensors.
    Lower(BinaryOperationDescription),
    /// Lower-than between a tensor and a scalar.
    LowerElem(ScalarOperationDescription<E>),
    /// Lower-or-equal between two tensors.
    LowerEqual(BinaryOperationDescription),
    /// Lower-or-equal between a tensor and a scalar.
    LowerEqualElem(ScalarOperationDescription<E>),
    /// Argmax.
    ArgMax(ScalarOperationDescription<usize>),
    /// Argmin.
    ArgMin(ScalarOperationDescription<usize>),
    /// Maximum.
    Max(UnaryOperationDescription),
    /// Maximum along a dimension, also producing indices.
    MaxDimWithIndices(ReduceDimWithIndicesDescription),
    /// Minimum along a dimension, also producing indices.
    MinDimWithIndices(ReduceDimWithIndicesDescription),
    /// Minimum.
    Min(UnaryOperationDescription),
    /// Maximum along a dimension.
    MaxDim(ScalarOperationDescription<usize>),
    /// Minimum along a dimension.
    MinDim(ScalarOperationDescription<usize>),
    /// Clamp between two scalar bounds.
    Clamp(ClampOperationDescription<E>),
    /// Random int tensor creation.
    IntRandom(RandomOperationDescription),
    /// Power of two tensors.
    Powf(BinaryOperationDescription),
}
517
/// Description of an int-specific operation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum IntOperationDescription {
    /// Conversion into a float tensor.
    IntoFloat(UnaryOperationDescription),
}
524
/// Description of a bool-specific operation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum BoolOperationDescription {
    /// Conversion into a float tensor.
    IntoFloat(UnaryOperationDescription),
    /// Conversion into an int tensor.
    IntoInt(UnaryOperationDescription),
    /// Logical not.
    Not(UnaryOperationDescription),
}
535
/// Description of a swap-dims operation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct SwapDimsDescription {
    /// Input tensor description.
    pub input: TensorDescription,
    /// Output tensor description.
    pub out: TensorDescription,
    /// First dimension of the swap.
    pub dim1: usize,
    /// Second dimension of the swap.
    pub dim2: usize,
}
548
/// Description of a permute operation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct PermuteOperationDescription {
    /// Input tensor description.
    pub input: TensorDescription,
    /// Output tensor description.
    pub out: TensorDescription,
    /// Axes used by the permutation.
    pub axes: Vec<usize>,
}
559
/// Description of an expand operation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct ExpandOperationDescription {
    /// Input tensor description.
    pub input: TensorDescription,
    /// Output tensor description.
    pub out: TensorDescription,
    /// Shape used by the expansion.
    pub shape: Vec<usize>,
}
570
/// Description of a flip operation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct FlipOperationDescription {
    /// Input tensor description.
    pub input: TensorDescription,
    /// Output tensor description.
    pub out: TensorDescription,
    /// Axes used by the flip.
    pub axes: Vec<usize>,
}
581
/// Description of a random tensor creation: the output tensor and the [`Distribution`]
/// to use. `Hash` is not derived for this type.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct RandomOperationDescription {
    pub out: TensorDescription,
    pub distribution: Distribution,
}
588
/// Description of a reshape operation: an input tensor and an output tensor.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ReshapeDescription {
    pub input: TensorDescription,
    pub out: TensorDescription,
}
595
/// Input/output tensor pair for an expand operation.
///
/// NOTE(review): `BaseOperationDescription::Expand` carries `ExpandOperationDescription`,
/// not this type; check whether this struct is still used elsewhere.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ExpandDescription {
    pub input: TensorDescription,
    pub out: TensorDescription,
}
602
/// Description of an operation on two tensors (`lhs`, `rhs`) producing `out`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct BinaryOperationDescription {
    pub lhs: TensorDescription,
    pub rhs: TensorDescription,
    pub out: TensorDescription,
}
610
/// Description of an operation on a single tensor (`input`) producing `out`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct UnaryOperationDescription {
    pub input: TensorDescription,
    pub out: TensorDescription,
}
617
/// Description of an operation between a tensor (`lhs`) and a scalar of type `E` (`rhs`)
/// producing `out`. `Hash` is not derived for this type.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ScalarOperationDescription<E> {
    pub lhs: TensorDescription,
    pub rhs: E,
    pub out: TensorDescription,
}
625
/// Description of a gather operation: source tensor, dimension, indices tensor and output.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct GatherOperationDescription {
    pub tensor: TensorDescription,
    pub dim: usize,
    pub indices: TensorDescription,
    pub out: TensorDescription,
}
634
/// Description of a scatter operation: target tensor, dimension, indices tensor,
/// value tensor and output.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ScatterOperationDescription {
    pub tensor: TensorDescription,
    pub dim: usize,
    pub indices: TensorDescription,
    pub value: TensorDescription,
    pub out: TensorDescription,
}
644
/// Description of a select operation: source tensor, dimension, indices tensor and output.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct SelectOperationDescription {
    pub tensor: TensorDescription,
    pub dim: usize,
    pub indices: TensorDescription,
    pub out: TensorDescription,
}
653
/// Description of a select-assign operation: target tensor, dimension, indices tensor,
/// value tensor and output.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct SelectAssignOperationDescription {
    pub tensor: TensorDescription,
    pub dim: usize,
    pub indices: TensorDescription,
    pub value: TensorDescription,
    pub out: TensorDescription,
}
663
/// Description of a slice operation: source tensor, a list of index ranges and output.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct SliceOperationDescription {
    pub tensor: TensorDescription,
    pub ranges: Vec<Range<usize>>,
    pub out: TensorDescription,
}
671
/// Description of a slice-assign operation: target tensor, a list of index ranges,
/// value tensor and output.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct SliceAssignOperationDescription {
    pub tensor: TensorDescription,
    pub ranges: Vec<Range<usize>>,
    pub value: TensorDescription,
    pub out: TensorDescription,
}
680
/// Description of a mask-where operation: source tensor, mask tensor, value tensor and output.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaskWhereOperationDescription {
    pub tensor: TensorDescription,
    pub mask: TensorDescription,
    pub value: TensorDescription,
    pub out: TensorDescription,
}
689
/// Description of a mask-fill operation: source tensor, mask tensor, scalar fill value of
/// type `E` and output. `Hash` is not derived for this type.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaskFillOperationDescription<E> {
    pub tensor: TensorDescription,
    pub mask: TensorDescription,
    pub value: E,
    pub out: TensorDescription,
}
698
/// Description of a clamp operation: source tensor, scalar `min` and `max` bounds of type
/// `E`, and output. `Hash` is not derived for this type.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ClampOperationDescription<E> {
    pub tensor: TensorDescription,
    pub min: E,
    pub max: E,
    pub out: TensorDescription,
}
707
/// Description of a repeat-dim operation: source tensor, dimension, repetition count
/// and output.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct RepeatDimOperationDescription {
    pub tensor: TensorDescription,
    pub dim: usize,
    pub times: usize,
    pub out: TensorDescription,
}
716
/// Description of a concatenation: the tensors to concatenate, the dimension and the output.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct CatOperationDescription {
    pub tensors: Vec<TensorDescription>,
    pub dim: usize,
    pub out: TensorDescription,
}
724
/// Description of a reduction along a dimension that produces two outputs: the reduced
/// values (`out`) and the indices (`out_indices`).
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ReduceDimWithIndicesDescription {
    pub tensor: TensorDescription,
    pub dim: usize,
    pub out: TensorDescription,
    pub out_indices: TensorDescription,
}
733
/// Description of an embedding operation: weights tensor, indices tensor and output.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct EmbeddingDescription {
    pub weights: TensorDescription,
    pub indices: TensorDescription,
    pub out: TensorDescription,
}
741
/// Description of an embedding backward operation: weights tensor, output-gradient tensor,
/// indices tensor and output.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct EmbeddingBackwardDescription {
    pub weights: TensorDescription,
    pub out_grad: TensorDescription,
    pub indices: TensorDescription,
    pub out: TensorDescription,
}
750
/// Description of a 1D convolution: input `x`, `weight`, optional `bias`, options and output.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv1dDescription {
    pub x: TensorDescription,
    pub weight: TensorDescription,
    pub bias: Option<TensorDescription>,
    pub options: Conv1dOptionsDescription,
    pub out: TensorDescription,
}
760
/// Description of a 2D convolution: input `x`, `weight`, optional `bias`, options and output.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv2dDescription {
    pub x: TensorDescription,
    pub weight: TensorDescription,
    pub bias: Option<TensorDescription>,
    pub options: Conv2dOptionsDescription,
    pub out: TensorDescription,
}
770
/// Description of a 2D deformable convolution: input `x`, `offset`, `weight`, optional
/// `mask` and `bias`, options and output.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct DeformConv2dDescription {
    pub x: TensorDescription,
    pub offset: TensorDescription,
    pub weight: TensorDescription,
    pub mask: Option<TensorDescription>,
    pub bias: Option<TensorDescription>,
    pub options: DeformableConv2dOptionsDescription,
    pub out: TensorDescription,
}
782
/// Description of the backward pass of a 2D deformable convolution.
///
/// Holds the forward operands (`x`, `offset`, `weight`, optional `mask` and `bias`), the
/// output gradient (`out_grad`), the options, and one `*_grad` tensor description per
/// forward operand; `mask_grad` and `bias_grad` are optional like their forward counterparts.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct DeformConv2dBackwardDescription {
    pub x: TensorDescription,
    pub offset: TensorDescription,
    pub weight: TensorDescription,
    pub mask: Option<TensorDescription>,
    pub bias: Option<TensorDescription>,
    pub out_grad: TensorDescription,
    pub options: DeformableConv2dOptionsDescription,
    pub input_grad: TensorDescription,
    pub offset_grad: TensorDescription,
    pub weight_grad: TensorDescription,
    pub mask_grad: Option<TensorDescription>,
    pub bias_grad: Option<TensorDescription>,
}
799
/// Description of a 3D convolution: input `x`, `weight`, optional `bias`, options and output.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv3dDescription {
    pub x: TensorDescription,
    pub weight: TensorDescription,
    pub bias: Option<TensorDescription>,
    pub options: Conv3dOptionsDescription,
    pub out: TensorDescription,
}
809
/// Description of a 1D transposed convolution: input `x`, `weight`, optional `bias`,
/// options and output.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose1dDescription {
    pub x: TensorDescription,
    pub weight: TensorDescription,
    pub bias: Option<TensorDescription>,
    pub options: ConvTranspose1dOptionsDescription,
    pub out: TensorDescription,
}
819
/// Description of a 2D transposed convolution: input `x`, `weight`, optional `bias`,
/// options and output.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose2dDescription {
    pub x: TensorDescription,
    pub weight: TensorDescription,
    pub bias: Option<TensorDescription>,
    pub options: ConvTranspose2dOptionsDescription,
    pub out: TensorDescription,
}
829
/// Description of a 3D transposed convolution: input `x`, `weight`, optional `bias`,
/// options and output.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose3dDescription {
    pub x: TensorDescription,
    pub weight: TensorDescription,
    pub bias: Option<TensorDescription>,
    pub options: ConvTranspose3dOptionsDescription,
    pub out: TensorDescription,
}
839
/// Options of a 1D convolution; converts to and from `ConvOptions<1>` field by field.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv1dOptionsDescription {
    pub stride: [usize; 1],
    pub padding: [usize; 1],
    pub dilation: [usize; 1],
    pub groups: usize,
}
848
/// Options of a 2D convolution; converts to and from `ConvOptions<2>` field by field.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv2dOptionsDescription {
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub dilation: [usize; 2],
    pub groups: usize,
}
857
/// Options of a 2D deformable convolution; converts to and from `DeformConvOptions<2>`
/// field by field.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct DeformableConv2dOptionsDescription {
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub dilation: [usize; 2],
    pub weight_groups: usize,
    pub offset_groups: usize,
}
867
/// Options of a 3D convolution; converts to and from `ConvOptions<3>` field by field.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv3dOptionsDescription {
    pub stride: [usize; 3],
    pub padding: [usize; 3],
    pub dilation: [usize; 3],
    pub groups: usize,
}
876
/// Options of a 1D transposed convolution; converts to and from `ConvTransposeOptions<1>`
/// field by field.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose1dOptionsDescription {
    pub stride: [usize; 1],
    pub padding: [usize; 1],
    pub padding_out: [usize; 1],
    pub dilation: [usize; 1],
    pub groups: usize,
}
886
/// Options of a 2D transposed convolution; converts to and from `ConvTransposeOptions<2>`
/// field by field.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose2dOptionsDescription {
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub padding_out: [usize; 2],
    pub dilation: [usize; 2],
    pub groups: usize,
}
896
/// Options of a 3D transposed convolution; converts to and from `ConvTransposeOptions<3>`
/// field by field.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose3dOptionsDescription {
    pub stride: [usize; 3],
    pub padding: [usize; 3],
    pub padding_out: [usize; 3],
    pub dilation: [usize; 3],
    pub groups: usize,
}
906
/// Description of quantization parameters, given as tensor descriptions.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct QuantizationParametersDescription {
    /// Scale tensor description.
    pub scale: TensorDescription,
    /// Optional offset tensor description.
    pub offset: Option<TensorDescription>,
}
915
/// Description of a quantize operation: source tensor, quantization parameters,
/// [`QuantizationScheme`] and output.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct QuantizeOperationDescription {
    pub tensor: TensorDescription,
    pub qparams: QuantizationParametersDescription,
    pub scheme: QuantizationScheme,
    pub out: TensorDescription,
}
924
/// Description of a dequantize operation: an input tensor and an output tensor.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct DequantizeOperationDescription {
    pub input: TensorDescription,
    pub out: TensorDescription,
}
931
932impl From<ConvOptions<1>> for Conv1dOptionsDescription {
933 fn from(value: ConvOptions<1>) -> Self {
934 Self {
935 stride: value.stride,
936 padding: value.padding,
937 dilation: value.dilation,
938 groups: value.groups,
939 }
940 }
941}
942
943impl From<ConvOptions<2>> for Conv2dOptionsDescription {
944 fn from(value: ConvOptions<2>) -> Self {
945 Self {
946 stride: value.stride,
947 padding: value.padding,
948 dilation: value.dilation,
949 groups: value.groups,
950 }
951 }
952}
953
954impl From<ConvOptions<3>> for Conv3dOptionsDescription {
955 fn from(value: ConvOptions<3>) -> Self {
956 Self {
957 stride: value.stride,
958 padding: value.padding,
959 dilation: value.dilation,
960 groups: value.groups,
961 }
962 }
963}
964
965impl From<DeformConvOptions<2>> for DeformableConv2dOptionsDescription {
966 fn from(value: DeformConvOptions<2>) -> Self {
967 Self {
968 stride: value.stride,
969 padding: value.padding,
970 dilation: value.dilation,
971 weight_groups: value.weight_groups,
972 offset_groups: value.offset_groups,
973 }
974 }
975}
976
977impl From<ConvTransposeOptions<1>> for ConvTranspose1dOptionsDescription {
978 fn from(value: ConvTransposeOptions<1>) -> Self {
979 Self {
980 stride: value.stride,
981 padding: value.padding,
982 padding_out: value.padding_out,
983 dilation: value.dilation,
984 groups: value.groups,
985 }
986 }
987}
988
989impl From<ConvTransposeOptions<2>> for ConvTranspose2dOptionsDescription {
990 fn from(value: ConvTransposeOptions<2>) -> Self {
991 Self {
992 stride: value.stride,
993 padding: value.padding,
994 padding_out: value.padding_out,
995 dilation: value.dilation,
996 groups: value.groups,
997 }
998 }
999}
1000
1001impl From<ConvTransposeOptions<3>> for ConvTranspose3dOptionsDescription {
1002 fn from(value: ConvTransposeOptions<3>) -> Self {
1003 Self {
1004 stride: value.stride,
1005 padding: value.padding,
1006 padding_out: value.padding_out,
1007 dilation: value.dilation,
1008 groups: value.groups,
1009 }
1010 }
1011}
1012
1013impl From<Conv1dOptionsDescription> for ConvOptions<1> {
1014 fn from(val: Conv1dOptionsDescription) -> Self {
1015 ConvOptions {
1016 stride: val.stride,
1017 padding: val.padding,
1018 dilation: val.dilation,
1019 groups: val.groups,
1020 }
1021 }
1022}
1023
1024impl From<Conv2dOptionsDescription> for ConvOptions<2> {
1025 fn from(val: Conv2dOptionsDescription) -> Self {
1026 ConvOptions {
1027 stride: val.stride,
1028 padding: val.padding,
1029 dilation: val.dilation,
1030 groups: val.groups,
1031 }
1032 }
1033}
1034
1035impl From<Conv3dOptionsDescription> for ConvOptions<3> {
1036 fn from(val: Conv3dOptionsDescription) -> Self {
1037 ConvOptions {
1038 stride: val.stride,
1039 padding: val.padding,
1040 dilation: val.dilation,
1041 groups: val.groups,
1042 }
1043 }
1044}
1045
1046impl From<DeformableConv2dOptionsDescription> for DeformConvOptions<2> {
1047 fn from(value: DeformableConv2dOptionsDescription) -> Self {
1048 DeformConvOptions {
1049 stride: value.stride,
1050 padding: value.padding,
1051 dilation: value.dilation,
1052 weight_groups: value.weight_groups,
1053 offset_groups: value.offset_groups,
1054 }
1055 }
1056}
1057
1058impl From<ConvTranspose1dOptionsDescription> for ConvTransposeOptions<1> {
1059 fn from(val: ConvTranspose1dOptionsDescription) -> Self {
1060 ConvTransposeOptions {
1061 stride: val.stride,
1062 padding: val.padding,
1063 padding_out: val.padding_out,
1064 dilation: val.dilation,
1065 groups: val.groups,
1066 }
1067 }
1068}
1069
1070impl From<ConvTranspose2dOptionsDescription> for ConvTransposeOptions<2> {
1071 fn from(val: ConvTranspose2dOptionsDescription) -> Self {
1072 ConvTransposeOptions {
1073 stride: val.stride,
1074 padding: val.padding,
1075 padding_out: val.padding_out,
1076 dilation: val.dilation,
1077 groups: val.groups,
1078 }
1079 }
1080}
1081
1082impl From<ConvTranspose3dOptionsDescription> for ConvTransposeOptions<3> {
1083 fn from(val: ConvTranspose3dOptionsDescription) -> Self {
1084 ConvTransposeOptions {
1085 stride: val.stride,
1086 padding: val.padding,
1087 padding_out: val.padding_out,
1088 dilation: val.dilation,
1089 groups: val.groups,
1090 }
1091 }
1092}
1093
/// Description of a 1D average pooling: input `x`, kernel size, stride, padding,
/// the `count_include_pad` flag and output.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AvgPool1dDescription {
    pub x: TensorDescription,
    pub kernel_size: usize,
    pub stride: usize,
    pub padding: usize,
    pub count_include_pad: bool,
    pub out: TensorDescription,
}
1104
/// Description of a 2D average pooling: input `x`, two-element kernel size, stride and
/// padding, the `count_include_pad` flag and output.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AvgPool2dDescription {
    pub x: TensorDescription,
    pub kernel_size: [usize; 2],
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub count_include_pad: bool,
    pub out: TensorDescription,
}
1115
/// Description of a 1D average pooling backward operation: input `x`, gradient `grad`,
/// the pooling parameters and output.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AvgPool1dBackwardDescription {
    pub x: TensorDescription,
    pub grad: TensorDescription,
    pub kernel_size: usize,
    pub stride: usize,
    pub padding: usize,
    pub count_include_pad: bool,
    pub out: TensorDescription,
}
1127
/// Description of a 2D average pooling backward operation: input `x`, gradient `grad`,
/// the pooling parameters and output.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AvgPool2dBackwardDescription {
    pub x: TensorDescription,
    pub grad: TensorDescription,
    pub kernel_size: [usize; 2],
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub count_include_pad: bool,
    pub out: TensorDescription,
}
1139
/// Description of a 1D adaptive average pooling: input `x`, requested output size and output.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool1dDescription {
    pub x: TensorDescription,
    pub output_size: usize,
    pub out: TensorDescription,
}
1147
/// Description of a 2D adaptive average pooling: input `x`, two-element output size
/// and output.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool2dDescription {
    pub x: TensorDescription,
    pub output_size: [usize; 2],
    pub out: TensorDescription,
}
1155
/// Description of a 1D adaptive average pooling backward operation: input `x`,
/// gradient `grad` and output.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool1dBackwardDescription {
    pub x: TensorDescription,
    pub grad: TensorDescription,
    pub out: TensorDescription,
}
1163
/// Description of a 2D adaptive average pooling backward operation: input `x`,
/// gradient `grad` and output.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool2dBackwardDescription {
    pub x: TensorDescription,
    pub grad: TensorDescription,
    pub out: TensorDescription,
}
1171
/// Description of a 1D max pooling: input `x`, kernel size, stride, padding, dilation
/// and output.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool1dDescription {
    pub x: TensorDescription,
    pub kernel_size: usize,
    pub stride: usize,
    pub padding: usize,
    pub dilation: usize,
    pub out: TensorDescription,
}
1182
/// Description of a 1D max pooling with two outputs: the pooled tensor (`out`) and the
/// indices tensor (`out_indices`).
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool1dWithIndicesDescription {
    pub x: TensorDescription,
    pub kernel_size: usize,
    pub stride: usize,
    pub padding: usize,
    pub dilation: usize,
    pub out: TensorDescription,
    pub out_indices: TensorDescription,
}
1194
/// Description of the backward pass of a 1D max pooling with indices: input `x`,
/// gradient `grad`, `indices`, the pooling parameters and output.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool1dWithIndicesBackwardDescription {
    pub x: TensorDescription,
    pub grad: TensorDescription,
    pub indices: TensorDescription,
    pub kernel_size: usize,
    pub stride: usize,
    pub padding: usize,
    pub dilation: usize,
    pub out: TensorDescription,
}
1207
/// Description of a 2D max pooling: input `x`, two-element kernel size, stride, padding
/// and dilation, and output.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool2dDescription {
    pub x: TensorDescription,
    pub kernel_size: [usize; 2],
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub dilation: [usize; 2],
    pub out: TensorDescription,
}
1218
/// Description of a 2D max pooling with two outputs: the pooled tensor (`out`) and the
/// indices tensor (`out_indices`).
#[allow(missing_docs)]
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct MaxPool2dWithIndicesDescription {
    pub x: TensorDescription,
    pub kernel_size: [usize; 2],
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub dilation: [usize; 2],
    pub out: TensorDescription,
    pub out_indices: TensorDescription,
}
1230
/// Description of the backward pass of a 2D max pooling with indices: input `x`,
/// gradient `grad`, `indices`, the pooling parameters and output.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool2dWithIndicesBackwardDescription {
    pub x: TensorDescription,
    pub grad: TensorDescription,
    pub indices: TensorDescription,
    pub kernel_size: [usize; 2],
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub dilation: [usize; 2],
    pub out: TensorDescription,
}
1243
/// Interpolation mode; mirrors the `Nearest`, `Bilinear` and `Bicubic` variants of
/// [`InterpolateMode`] and converts to and from it.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub enum InterpolateModeDescription {
    Nearest,
    Bilinear,
    Bicubic,
}
1251
/// Interpolation options; holds only the interpolation mode and converts to and from
/// [`InterpolateOptions`].
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct InterpolateOptionsDescription {
    pub mode: InterpolateModeDescription,
}
1257
/// Description of an interpolate operation: input `x`, two-element output size, options
/// and output.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct InterpolateDescription {
    pub x: TensorDescription,
    pub output_size: [usize; 2],
    pub options: InterpolateOptionsDescription,
    pub out: TensorDescription,
}
1266
1267impl From<InterpolateModeDescription> for InterpolateMode {
1268 fn from(val: InterpolateModeDescription) -> Self {
1269 match val {
1270 InterpolateModeDescription::Nearest => Self::Nearest,
1271 InterpolateModeDescription::Bilinear => Self::Bilinear,
1272 InterpolateModeDescription::Bicubic => Self::Bicubic,
1273 }
1274 }
1275}
1276
1277impl From<InterpolateOptionsDescription> for InterpolateOptions {
1278 fn from(val: InterpolateOptionsDescription) -> Self {
1279 Self {
1280 mode: val.mode.into(),
1281 }
1282 }
1283}
1284
1285impl From<InterpolateMode> for InterpolateModeDescription {
1286 fn from(val: InterpolateMode) -> Self {
1287 match val {
1288 InterpolateMode::Nearest => Self::Nearest,
1289 InterpolateMode::Bilinear => Self::Bilinear,
1290 InterpolateMode::Bicubic => Self::Bicubic,
1291 }
1292 }
1293}
1294
1295impl From<InterpolateOptions> for InterpolateOptionsDescription {
1296 fn from(val: InterpolateOptions) -> Self {
1297 Self {
1298 mode: val.mode.into(),
1299 }
1300 }
1301}
1302
/// Description of an interpolate backward operation: input `x`, gradient `grad`,
/// two-element output size, options and output.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct InterpolateBackwardDescription {
    pub x: TensorDescription,
    pub grad: TensorDescription,
    pub output_size: [usize; 2],
    pub options: InterpolateOptionsDescription,
    pub out: TensorDescription,
}
1312
1313impl OperationDescription {
1314 pub fn nodes(&self) -> Vec<&TensorDescription> {
1316 match self {
1317 OperationDescription::BaseFloat(ops) => ops.nodes(),
1318 OperationDescription::BaseInt(ops) => ops.nodes(),
1319 OperationDescription::BaseBool(ops) => ops.nodes(),
1320 OperationDescription::NumericFloat(_dtype, ops) => ops.nodes(),
1321 OperationDescription::NumericInt(_dtype, ops) => ops.nodes(),
1322 OperationDescription::Bool(ops) => ops.nodes(),
1323 OperationDescription::Int(ops) => ops.nodes(),
1324 OperationDescription::Float(_dtype, ops) => ops.nodes(),
1325 OperationDescription::Module(ops) => ops.nodes(),
1326 OperationDescription::Custom(ops) => ops.nodes(),
1327 }
1328 }
1329}
1330
1331impl BaseOperationDescription {
1332 fn nodes(&self) -> Vec<&TensorDescription> {
1333 match self {
1334 BaseOperationDescription::ToDevice(desc) => vec![desc],
1335 BaseOperationDescription::Reshape(desc) => {
1336 vec![&desc.input, &desc.out]
1337 }
1338 BaseOperationDescription::SwapDims(desc) => {
1339 vec![&desc.input, &desc.out]
1340 }
1341 BaseOperationDescription::Permute(desc) => {
1342 vec![&desc.input, &desc.out]
1343 }
1344
1345 BaseOperationDescription::Expand(desc) => {
1346 vec![&desc.input, &desc.out]
1347 }
1348
1349 BaseOperationDescription::Flip(desc) => {
1350 vec![&desc.input, &desc.out]
1351 }
1352 BaseOperationDescription::Slice(desc) => {
1353 vec![&desc.tensor, &desc.out]
1354 }
1355 BaseOperationDescription::SliceAssign(desc) => {
1356 vec![&desc.tensor, &desc.value, &desc.out]
1357 }
1358 BaseOperationDescription::Equal(desc) => {
1359 vec![&desc.lhs, &desc.rhs, &desc.out]
1360 }
1361 BaseOperationDescription::RepeatDim(desc) => {
1362 vec![&desc.tensor, &desc.out]
1363 }
1364 BaseOperationDescription::Cat(desc) => desc.tensors.iter().collect(),
1365 BaseOperationDescription::Cast(desc) => vec![&desc.input, &desc.out],
1366 BaseOperationDescription::Empty(desc) => vec![desc],
1367 }
1368 }
1369}
1370
1371impl<E: Element> NumericOperationDescription<E> {
1372 fn nodes(&self) -> Vec<&TensorDescription> {
1373 match self {
1374 NumericOperationDescription::Add(desc) => {
1375 vec![&desc.lhs, &desc.rhs, &desc.out]
1376 }
1377 NumericOperationDescription::AddScalar(desc) => {
1378 vec![&desc.lhs, &desc.out]
1379 }
1380 NumericOperationDescription::Sub(desc) => {
1381 vec![&desc.lhs, &desc.rhs, &desc.out]
1382 }
1383 NumericOperationDescription::SubScalar(desc) => {
1384 vec![&desc.lhs, &desc.out]
1385 }
1386 NumericOperationDescription::Mul(desc) => {
1387 vec![&desc.lhs, &desc.rhs, &desc.out]
1388 }
1389 NumericOperationDescription::MulScalar(desc) => {
1390 vec![&desc.lhs, &desc.out]
1391 }
1392 NumericOperationDescription::Div(desc) => {
1393 vec![&desc.lhs, &desc.rhs, &desc.out]
1394 }
1395 NumericOperationDescription::DivScalar(desc) => {
1396 vec![&desc.lhs, &desc.out]
1397 }
1398 NumericOperationDescription::Rem(desc) => {
1399 vec![&desc.lhs, &desc.rhs, &desc.out]
1400 }
1401 NumericOperationDescription::RemScalar(desc) => {
1402 vec![&desc.lhs, &desc.out]
1403 }
1404 NumericOperationDescription::Ones(desc) => vec![desc],
1405 NumericOperationDescription::Gather(desc) => {
1406 vec![&desc.tensor, &desc.indices, &desc.out]
1407 }
1408 NumericOperationDescription::Scatter(desc) => {
1409 vec![&desc.tensor, &desc.indices, &desc.value, &desc.out]
1410 }
1411 NumericOperationDescription::Select(desc) => {
1412 vec![&desc.tensor, &desc.indices, &desc.out]
1413 }
1414 NumericOperationDescription::SelectAssign(desc) => {
1415 vec![&desc.tensor, &desc.indices, &desc.value, &desc.out]
1416 }
1417 NumericOperationDescription::MaskWhere(desc) => {
1418 vec![&desc.tensor, &desc.mask, &desc.value, &desc.out]
1419 }
1420 NumericOperationDescription::MaskFill(desc) => {
1421 vec![&desc.tensor, &desc.mask, &desc.out]
1422 }
1423 NumericOperationDescription::EqualElem(desc) => {
1424 vec![&desc.lhs, &desc.out]
1425 }
1426 NumericOperationDescription::GreaterElem(desc) => {
1427 vec![&desc.lhs, &desc.out]
1428 }
1429 NumericOperationDescription::GreaterEqualElem(desc) => {
1430 vec![&desc.lhs, &desc.out]
1431 }
1432 NumericOperationDescription::LowerElem(desc) => {
1433 vec![&desc.lhs, &desc.out]
1434 }
1435 NumericOperationDescription::LowerEqualElem(desc) => {
1436 vec![&desc.lhs, &desc.out]
1437 }
1438 NumericOperationDescription::Greater(desc) => {
1439 vec![&desc.lhs, &desc.rhs, &desc.out]
1440 }
1441 NumericOperationDescription::GreaterEqual(desc) => {
1442 vec![&desc.lhs, &desc.rhs, &desc.out]
1443 }
1444 NumericOperationDescription::Lower(desc) => {
1445 vec![&desc.lhs, &desc.rhs, &desc.out]
1446 }
1447 NumericOperationDescription::LowerEqual(desc) => {
1448 vec![&desc.lhs, &desc.rhs, &desc.out]
1449 }
1450 NumericOperationDescription::ArgMax(desc) => {
1451 vec![&desc.lhs, &desc.out]
1452 }
1453 NumericOperationDescription::ArgMin(desc) => {
1454 vec![&desc.lhs, &desc.out]
1455 }
1456 NumericOperationDescription::Clamp(desc) => {
1457 vec![&desc.tensor, &desc.out]
1458 }
1459 NumericOperationDescription::Abs(desc) => {
1460 vec![&desc.input, &desc.out]
1461 }
1462 NumericOperationDescription::Zeros(desc) => vec![desc],
1463 NumericOperationDescription::Full(desc) => vec![&desc.0],
1464 NumericOperationDescription::MeanDim(desc) => {
1465 vec![&desc.lhs, &desc.out]
1466 }
1467 NumericOperationDescription::Mean(desc) => {
1468 vec![&desc.input, &desc.out]
1469 }
1470 NumericOperationDescription::Sum(desc) => {
1471 vec![&desc.input, &desc.out]
1472 }
1473 NumericOperationDescription::SumDim(desc) => {
1474 vec![&desc.lhs, &desc.out]
1475 }
1476 NumericOperationDescription::Prod(desc) => {
1477 vec![&desc.input, &desc.out]
1478 }
1479 NumericOperationDescription::ProdDim(desc) => {
1480 vec![&desc.lhs, &desc.out]
1481 }
1482 NumericOperationDescription::Max(desc) => {
1483 vec![&desc.input, &desc.out]
1484 }
1485 NumericOperationDescription::MaxDimWithIndices(desc) => {
1486 vec![&desc.tensor, &desc.out_indices, &desc.out]
1487 }
1488 NumericOperationDescription::MinDimWithIndices(desc) => {
1489 vec![&desc.tensor, &desc.out_indices, &desc.out]
1490 }
1491 NumericOperationDescription::Min(desc) => {
1492 vec![&desc.input, &desc.out]
1493 }
1494 NumericOperationDescription::MaxDim(desc) => {
1495 vec![&desc.lhs, &desc.out]
1496 }
1497 NumericOperationDescription::MinDim(desc) => {
1498 vec![&desc.lhs, &desc.out]
1499 }
1500 NumericOperationDescription::IntRandom(desc) => {
1501 vec![&desc.out]
1502 }
1503 NumericOperationDescription::Powf(desc) => {
1504 vec![&desc.lhs, &desc.rhs, &desc.out]
1505 }
1506 }
1507 }
1508}
1509
1510impl FloatOperationDescription {
1511 fn nodes(&self) -> Vec<&TensorDescription> {
1512 match self {
1513 FloatOperationDescription::Matmul(desc) => {
1514 vec![&desc.lhs, &desc.rhs, &desc.out]
1515 }
1516 FloatOperationDescription::Random(desc) => vec![&desc.out],
1517 FloatOperationDescription::Exp(desc) => vec![&desc.input, &desc.out],
1518 FloatOperationDescription::Log(desc) => vec![&desc.input, &desc.out],
1519 FloatOperationDescription::Log1p(desc) => vec![&desc.input, &desc.out],
1520 FloatOperationDescription::Erf(desc) => vec![&desc.input, &desc.out],
1521 FloatOperationDescription::Recip(desc) => vec![&desc.input, &desc.out],
1522 FloatOperationDescription::PowfScalar(desc) => vec![&desc.lhs, &desc.out],
1523 FloatOperationDescription::Sqrt(desc) => vec![&desc.input, &desc.out],
1524 FloatOperationDescription::Cos(desc) => vec![&desc.input, &desc.out],
1525 FloatOperationDescription::Sin(desc) => vec![&desc.input, &desc.out],
1526 FloatOperationDescription::Tanh(desc) => vec![&desc.input, &desc.out],
1527 FloatOperationDescription::Round(desc) => vec![&desc.input, &desc.out],
1528 FloatOperationDescription::Floor(desc) => vec![&desc.input, &desc.out],
1529 FloatOperationDescription::Ceil(desc) => vec![&desc.input, &desc.out],
1530 FloatOperationDescription::IntoInt(desc) => vec![&desc.input, &desc.out],
1531 FloatOperationDescription::Quantize(desc) => {
1532 if let Some(offset) = &desc.qparams.offset {
1533 vec![&desc.tensor, &desc.qparams.scale, &offset, &desc.out]
1534 } else {
1535 vec![&desc.tensor, &desc.qparams.scale, &desc.out]
1536 }
1537 }
1538 FloatOperationDescription::Dequantize(desc) => vec![&desc.input, &desc.out],
1539 }
1540 }
1541}
1542
1543impl IntOperationDescription {
1544 fn nodes(&self) -> Vec<&TensorDescription> {
1545 match self {
1546 IntOperationDescription::IntoFloat(desc) => vec![&desc.input, &desc.out],
1547 }
1548 }
1549}
1550
1551impl BoolOperationDescription {
1552 fn nodes(&self) -> Vec<&TensorDescription> {
1553 match self {
1554 BoolOperationDescription::IntoFloat(desc) => vec![&desc.input, &desc.out],
1555 BoolOperationDescription::IntoInt(desc) => vec![&desc.input, &desc.out],
1556 BoolOperationDescription::Not(desc) => vec![&desc.input, &desc.out],
1557 }
1558 }
1559}
1560
1561impl ModuleOperationDescription {
1562 fn nodes(&self) -> Vec<&TensorDescription> {
1563 match self {
1564 ModuleOperationDescription::Embedding(desc) => {
1565 vec![&desc.weights, &desc.indices, &desc.out]
1566 }
1567 ModuleOperationDescription::EmbeddingBackward(desc) => {
1568 vec![&desc.weights, &desc.out_grad, &desc.indices, &desc.out]
1569 }
1570 ModuleOperationDescription::Conv1d(desc) => {
1571 if let Some(bias) = &desc.bias {
1572 vec![&desc.x, &desc.weight, &bias, &desc.out]
1573 } else {
1574 vec![&desc.x, &desc.weight, &desc.out]
1575 }
1576 }
1577 ModuleOperationDescription::Conv2d(desc) => {
1578 if let Some(bias) = &desc.bias {
1579 vec![&desc.x, &desc.weight, &bias, &desc.out]
1580 } else {
1581 vec![&desc.x, &desc.weight, &desc.out]
1582 }
1583 }
1584 ModuleOperationDescription::Conv3d(desc) => {
1585 if let Some(bias) = &desc.bias {
1586 vec![&desc.x, &desc.weight, &bias, &desc.out]
1587 } else {
1588 vec![&desc.x, &desc.weight, &desc.out]
1589 }
1590 }
1591 ModuleOperationDescription::DeformableConv2d(desc) => match (&desc.mask, &desc.bias) {
1592 (Some(mask), Some(bias)) => vec![&desc.x, &desc.offset, &desc.weight, &mask, &bias],
1593 (Some(mask), None) => vec![&desc.x, &desc.offset, &desc.weight, &mask],
1594 (None, Some(bias)) => vec![&desc.x, &desc.offset, &desc.weight, &bias],
1595 (None, None) => vec![&desc.x, &desc.offset, &desc.weight],
1596 },
1597 ModuleOperationDescription::DeformableConv2dBackward(desc) => {
1598 match (&desc.mask, &desc.bias) {
1599 (Some(mask), Some(bias)) => {
1600 vec![&desc.x, &desc.offset, &desc.weight, &mask, &bias]
1601 }
1602 (Some(mask), None) => vec![&desc.x, &desc.offset, &desc.weight, &mask],
1603 (None, Some(bias)) => vec![&desc.x, &desc.offset, &desc.weight, &bias],
1604 (None, None) => vec![&desc.x, &desc.offset, &desc.weight],
1605 }
1606 }
1607 ModuleOperationDescription::ConvTranspose1d(desc) => {
1608 if let Some(bias) = &desc.bias {
1609 vec![&desc.x, &desc.weight, &bias, &desc.out]
1610 } else {
1611 vec![&desc.x, &desc.weight, &desc.out]
1612 }
1613 }
1614 ModuleOperationDescription::ConvTranspose2d(desc) => {
1615 if let Some(bias) = &desc.bias {
1616 vec![&desc.x, &desc.weight, &bias, &desc.out]
1617 } else {
1618 vec![&desc.x, &desc.weight, &desc.out]
1619 }
1620 }
1621 ModuleOperationDescription::ConvTranspose3d(desc) => {
1622 if let Some(bias) = &desc.bias {
1623 vec![&desc.x, &desc.weight, &bias, &desc.out]
1624 } else {
1625 vec![&desc.x, &desc.weight, &desc.out]
1626 }
1627 }
1628 ModuleOperationDescription::AvgPool1d(desc) => {
1629 vec![&desc.x, &desc.out]
1630 }
1631 ModuleOperationDescription::AvgPool2d(desc) => {
1632 vec![&desc.x, &desc.out]
1633 }
1634 ModuleOperationDescription::AvgPool1dBackward(desc) => {
1635 vec![&desc.x, &desc.out, &desc.grad]
1636 }
1637 ModuleOperationDescription::AvgPool2dBackward(desc) => {
1638 vec![&desc.x, &desc.out, &desc.grad]
1639 }
1640 ModuleOperationDescription::AdaptiveAvgPool1d(desc) => {
1641 vec![&desc.x, &desc.out]
1642 }
1643 ModuleOperationDescription::AdaptiveAvgPool2d(desc) => {
1644 vec![&desc.x, &desc.out]
1645 }
1646 ModuleOperationDescription::AdaptiveAvgPool1dBackward(desc) => {
1647 vec![&desc.x, &desc.out, &desc.grad]
1648 }
1649 ModuleOperationDescription::AdaptiveAvgPool2dBackward(desc) => {
1650 vec![&desc.x, &desc.out, &desc.grad]
1651 }
1652 ModuleOperationDescription::MaxPool1d(desc) => {
1653 vec![&desc.x, &desc.out]
1654 }
1655 ModuleOperationDescription::MaxPool1dWithIndices(desc) => {
1656 vec![&desc.x, &desc.out, &desc.out_indices]
1657 }
1658 ModuleOperationDescription::MaxPool1dWithIndicesBackward(desc) => {
1659 vec![&desc.x, &desc.out, &desc.indices, &desc.grad]
1660 }
1661 ModuleOperationDescription::MaxPool2d(desc) => {
1662 vec![&desc.x, &desc.out]
1663 }
1664 ModuleOperationDescription::MaxPool2dWithIndices(desc) => {
1665 vec![&desc.x, &desc.out, &desc.out_indices]
1666 }
1667 ModuleOperationDescription::MaxPool2dWithIndicesBackward(desc) => {
1668 vec![&desc.x, &desc.out, &desc.indices, &desc.grad]
1669 }
1670 ModuleOperationDescription::Interpolate(desc) => {
1671 vec![&desc.x, &desc.out]
1672 }
1673 ModuleOperationDescription::InterpolateBackward(desc) => {
1674 vec![&desc.x, &desc.out, &desc.grad]
1675 }
1676 }
1677 }
1678}
1679
1680impl core::hash::Hash for RandomOperationDescription {
1681 fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
1682 self.out.hash(state);
1683
1684 match self.distribution {
1685 Distribution::Default => 1u8.hash(state),
1686 Distribution::Bernoulli(_) => 2u8.hash(state),
1687 Distribution::Uniform(_, _) => 3u8.hash(state),
1688 Distribution::Normal(_, _) => 4u8.hash(state),
1689 }
1690 }
1691}
1692
1693impl<E> core::hash::Hash for ScalarOperationDescription<E> {
1694 fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
1695 self.lhs.hash(state);
1696 self.out.hash(state);
1697 }
1698}
1699
1700impl<E> core::hash::Hash for MaskFillOperationDescription<E> {
1701 fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
1702 self.tensor.hash(state);
1703 self.mask.hash(state);
1704 self.out.hash(state);
1705 }
1706}
1707
1708impl<E> core::hash::Hash for ClampOperationDescription<E> {
1709 fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
1710 self.tensor.hash(state);
1711 self.out.hash(state);
1712 }
1713}
1714
1715impl<E> core::hash::Hash for NumericOperationDescription<E> {
1716 fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
1717 match self {
1718 NumericOperationDescription::Add(desc) => desc.hash(state),
1719 NumericOperationDescription::AddScalar(desc) => desc.hash(state),
1720 NumericOperationDescription::Sub(desc) => desc.hash(state),
1721 NumericOperationDescription::SubScalar(desc) => desc.hash(state),
1722 NumericOperationDescription::Div(desc) => desc.hash(state),
1723 NumericOperationDescription::DivScalar(desc) => desc.hash(state),
1724 NumericOperationDescription::Rem(desc) => desc.hash(state),
1725 NumericOperationDescription::RemScalar(desc) => desc.hash(state),
1726 NumericOperationDescription::Mul(desc) => desc.hash(state),
1727 NumericOperationDescription::MulScalar(desc) => desc.hash(state),
1728 NumericOperationDescription::Abs(desc) => desc.hash(state),
1729 NumericOperationDescription::Ones(desc) => desc.hash(state),
1730 NumericOperationDescription::Zeros(desc) => desc.hash(state),
1731 NumericOperationDescription::Full(desc) => desc.0.hash(state),
1732 NumericOperationDescription::Gather(desc) => desc.hash(state),
1733 NumericOperationDescription::Scatter(desc) => desc.hash(state),
1734 NumericOperationDescription::Select(desc) => desc.hash(state),
1735 NumericOperationDescription::SelectAssign(desc) => desc.hash(state),
1736 NumericOperationDescription::MaskWhere(desc) => desc.hash(state),
1737 NumericOperationDescription::MaskFill(desc) => desc.hash(state),
1738 NumericOperationDescription::MeanDim(desc) => desc.hash(state),
1739 NumericOperationDescription::Mean(desc) => desc.hash(state),
1740 NumericOperationDescription::Sum(desc) => desc.hash(state),
1741 NumericOperationDescription::SumDim(desc) => desc.hash(state),
1742 NumericOperationDescription::Prod(desc) => desc.hash(state),
1743 NumericOperationDescription::ProdDim(desc) => desc.hash(state),
1744 NumericOperationDescription::EqualElem(desc) => desc.hash(state),
1745 NumericOperationDescription::Greater(desc) => desc.hash(state),
1746 NumericOperationDescription::GreaterElem(desc) => desc.hash(state),
1747 NumericOperationDescription::GreaterEqual(desc) => desc.hash(state),
1748 NumericOperationDescription::GreaterEqualElem(desc) => desc.hash(state),
1749 NumericOperationDescription::Lower(desc) => desc.hash(state),
1750 NumericOperationDescription::LowerElem(desc) => desc.hash(state),
1751 NumericOperationDescription::LowerEqual(desc) => desc.hash(state),
1752 NumericOperationDescription::LowerEqualElem(desc) => desc.hash(state),
1753 NumericOperationDescription::ArgMax(desc) => desc.hash(state),
1754 NumericOperationDescription::ArgMin(desc) => desc.hash(state),
1755 NumericOperationDescription::Max(desc) => desc.hash(state),
1756 NumericOperationDescription::MaxDimWithIndices(desc) => desc.hash(state),
1757 NumericOperationDescription::MinDimWithIndices(desc) => desc.hash(state),
1758 NumericOperationDescription::Min(desc) => desc.hash(state),
1759 NumericOperationDescription::MaxDim(desc) => desc.hash(state),
1760 NumericOperationDescription::MinDim(desc) => desc.hash(state),
1761 NumericOperationDescription::Clamp(desc) => desc.hash(state),
1762 NumericOperationDescription::IntRandom(desc) => desc.hash(state),
1763 NumericOperationDescription::Powf(desc) => desc.hash(state),
1764 }
1765 }
1766}