1use crate::{Graph, NodeId, Op};
7
8impl Graph {
9 pub fn conv2d(
11 &mut self,
12 input: NodeId,
13 weight: NodeId,
14 kernel_size: [usize; 2],
15 stride: [usize; 2],
16 padding: [usize; 2],
17 dilation: [usize; 2],
18 groups: usize,
19 ) -> NodeId {
20 let in_s = self.node(input).shape.clone();
21 let w_s = self.node(weight).shape.clone();
22 let out = crate::shape::conv2d_output_shape(
23 &in_s,
24 &w_s,
25 kernel_size,
26 stride,
27 padding,
28 dilation,
29 groups,
30 )
31 .expect("conv2d shape inference");
32 self.push(
33 Op::Conv {
34 kernel_size: kernel_size.to_vec(),
35 stride: stride.to_vec(),
36 padding: padding.to_vec(),
37 dilation: dilation.to_vec(),
38 groups,
39 },
40 vec![input, weight],
41 out,
42 None,
43 )
44 }
45
46 pub fn conv_transpose2d(
48 &mut self,
49 input: NodeId,
50 weight: NodeId,
51 kernel_size: [usize; 2],
52 stride: [usize; 2],
53 padding: [usize; 2],
54 dilation: [usize; 2],
55 output_padding: [usize; 2],
56 groups: usize,
57 ) -> NodeId {
58 let in_s = self.node(input).shape.clone();
59 let w_s = self.node(weight).shape.clone();
60 let out = crate::shape::conv_transpose2d_output_shape(
61 &in_s,
62 &w_s,
63 kernel_size,
64 stride,
65 padding,
66 dilation,
67 output_padding,
68 groups,
69 )
70 .expect("conv_transpose2d shape inference");
71 self.push(
72 Op::ConvTranspose2d {
73 kernel_size: kernel_size.to_vec(),
74 stride: stride.to_vec(),
75 padding: padding.to_vec(),
76 dilation: dilation.to_vec(),
77 output_padding: output_padding.to_vec(),
78 groups,
79 },
80 vec![input, weight],
81 out,
82 None,
83 )
84 }
85}