1use crate::{Graph, NodeId, Op};
19
20impl Graph {
21 pub fn conv2d(
23 &mut self,
24 input: NodeId,
25 weight: NodeId,
26 kernel_size: [usize; 2],
27 stride: [usize; 2],
28 padding: [usize; 2],
29 dilation: [usize; 2],
30 groups: usize,
31 ) -> NodeId {
32 let in_s = self.node(input).shape.clone();
33 let w_s = self.node(weight).shape.clone();
34 let out = crate::shape::conv2d_output_shape(
35 &in_s,
36 &w_s,
37 kernel_size,
38 stride,
39 padding,
40 dilation,
41 groups,
42 )
43 .expect("conv2d shape inference");
44 self.push(
45 Op::Conv {
46 kernel_size: kernel_size.to_vec(),
47 stride: stride.to_vec(),
48 padding: padding.to_vec(),
49 dilation: dilation.to_vec(),
50 groups,
51 },
52 vec![input, weight],
53 out,
54 None,
55 )
56 }
57
58 pub fn im2col(
60 &mut self,
61 input: NodeId,
62 kernel_size: [usize; 2],
63 stride: [usize; 2],
64 padding: [usize; 2],
65 dilation: [usize; 2],
66 ) -> NodeId {
67 let in_s = self.node(input).shape.clone();
68 let out = crate::shape::im2col_output_shape(&in_s, kernel_size, stride, padding, dilation)
69 .expect("im2col shape inference");
70 self.push(
71 Op::Im2Col {
72 kernel_size: kernel_size.to_vec(),
73 stride: stride.to_vec(),
74 padding: padding.to_vec(),
75 dilation: dilation.to_vec(),
76 },
77 vec![input],
78 out,
79 None,
80 )
81 }
82
83 pub fn conv_transpose2d(
85 &mut self,
86 input: NodeId,
87 weight: NodeId,
88 kernel_size: [usize; 2],
89 stride: [usize; 2],
90 padding: [usize; 2],
91 dilation: [usize; 2],
92 output_padding: [usize; 2],
93 groups: usize,
94 ) -> NodeId {
95 let in_s = self.node(input).shape.clone();
96 let w_s = self.node(weight).shape.clone();
97 let out = crate::shape::conv_transpose2d_output_shape(
98 &in_s,
99 &w_s,
100 kernel_size,
101 stride,
102 padding,
103 dilation,
104 output_padding,
105 groups,
106 )
107 .expect("conv_transpose2d shape inference");
108 self.push(
109 Op::ConvTranspose2d {
110 kernel_size: kernel_size.to_vec(),
111 stride: stride.to_vec(),
112 padding: padding.to_vec(),
113 dilation: dilation.to_vec(),
114 output_padding: output_padding.to_vec(),
115 groups,
116 },
117 vec![input, weight],
118 out,
119 None,
120 )
121 }
122}