1use crate::{op::BackpropOp, op::Op, Error, Result, Tensor};
2
/// Parameters describing a single-group 1d convolution.
///
/// `c_in`/`c_out` are per-group channel counts: `Tensor::conv1d` divides the
/// tensor channel counts by `groups` before building this struct.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv1D {
    // Batch size.
    pub(crate) b_size: usize,
    // Length of the input sequence.
    pub(crate) l_in: usize,
    // Number of output channels (per group).
    pub(crate) c_out: usize,
    // Number of input channels (per group).
    pub(crate) c_in: usize,
    // Kernel size.
    pub(crate) k_size: usize,
    // Zero-padding applied on both sides of the input.
    pub(crate) padding: usize,
    // Stride of the convolution.
    pub(crate) stride: usize,
    // Spacing between kernel elements.
    pub(crate) dilation: usize,
}
16
17impl ParamsConv1D {
18 pub(crate) fn l_out(&self) -> usize {
19 (self.l_in + 2 * self.padding - self.dilation * (self.k_size - 1) - 1) / self.stride + 1
20 }
21
22 pub(crate) fn out_dims(&self) -> Vec<usize> {
23 let l_out = self.l_out();
24 vec![self.b_size, self.c_out, l_out]
25 }
26}
27
/// Forward convolution algorithm selection for the cuDNN backend.
///
/// NOTE(review): the variants appear to mirror cuDNN's `cudnnConvolutionFwdAlgo_t`
/// enumeration — confirm the exact mapping against the cuDNN bindings before
/// relying on it.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CudnnFwdAlgo {
    ImplicitGemm,
    ImplicitPrecompGemm,
    Gemm,
    Direct,
    Fft,
    FftTiling,
    Winograd,
    WinogradNonFused,
    Count,
}
40
/// Parameters describing a single-group 2d convolution.
///
/// `c_in`/`c_out` are per-group channel counts: `Tensor::conv2d` divides the
/// tensor channel counts by `groups` before building this struct.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv2D {
    // Batch size.
    pub(crate) b_size: usize,
    // Input height.
    pub(crate) i_h: usize,
    // Input width.
    pub(crate) i_w: usize,
    // Kernel height.
    pub(crate) k_h: usize,
    // Kernel width.
    pub(crate) k_w: usize,
    // Number of output channels (per group).
    pub(crate) c_out: usize,
    // Number of input channels (per group).
    pub(crate) c_in: usize,
    // Zero-padding applied on all sides (same for height and width).
    pub(crate) padding: usize,
    // Stride of the convolution (same for height and width).
    pub(crate) stride: usize,
    // Spacing between kernel elements (same for height and width).
    pub(crate) dilation: usize,
    // Optional cuDNN forward-algorithm override; `None` lets the backend choose.
    pub cudnn_fwd_algo: Option<CudnnFwdAlgo>,
}
55
56impl ParamsConv2D {
57 pub(crate) fn out_h(&self) -> usize {
58 (self.i_h + 2 * self.padding - self.dilation * (self.k_h - 1) - 1) / self.stride + 1
59 }
60
61 pub(crate) fn out_w(&self) -> usize {
62 (self.i_w + 2 * self.padding - self.dilation * (self.k_w - 1) - 1) / self.stride + 1
63 }
64
65 pub(crate) fn out_dims(&self) -> Vec<usize> {
66 vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
67 }
68}
69
/// Parameters describing a 2d transposed convolution.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConvTranspose2D {
    // Batch size.
    pub(crate) b_size: usize,
    // Input height.
    pub(crate) i_h: usize,
    // Input width.
    pub(crate) i_w: usize,
    // Kernel height.
    pub(crate) k_h: usize,
    // Kernel width.
    pub(crate) k_w: usize,
    // Number of output channels.
    pub(crate) c_out: usize,
    // Number of input channels.
    pub(crate) c_in: usize,
    // Zero-padding that was applied by the corresponding forward convolution.
    pub(crate) padding: usize,
    // Extra size added to one side of the output.
    pub(crate) output_padding: usize,
    // Stride of the corresponding forward convolution.
    pub(crate) stride: usize,
    // Spacing between kernel elements.
    pub(crate) dilation: usize,
}
84
85impl ParamsConvTranspose2D {
86 pub(crate) fn out_h(&self) -> usize {
87 (self.i_h - 1) * self.stride + self.dilation * (self.k_h - 1) + self.output_padding + 1
88 - 2 * self.padding
89 }
90
91 pub(crate) fn out_w(&self) -> usize {
92 (self.i_w - 1) * self.stride + self.dilation * (self.k_w - 1) + self.output_padding + 1
93 - 2 * self.padding
94 }
95
96 pub(crate) fn out_dims(&self) -> Vec<usize> {
97 vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
98 }
99}
100
101impl Tensor {
102 fn conv1d_single_group(&self, kernel: &Self, params: &ParamsConv1D) -> Result<Self> {
103 let storage =
104 self.storage()
105 .conv1d(self.layout(), &kernel.storage(), kernel.layout(), params)?;
106 let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv1D {
107 arg,
108 kernel,
109 padding: params.padding,
110 stride: params.stride,
111 dilation: params.dilation,
112 });
113 let out_dims = params.out_dims();
114 Ok(crate::tensor::from_storage(storage, out_dims, op, false))
115 }
116
117 pub fn conv1d(
119 &self,
120 kernel: &Self,
121 padding: usize,
122 stride: usize,
123 dilation: usize,
124 groups: usize,
125 ) -> Result<Self> {
126 let (c_out, c_in_k, k_size) = kernel.dims3()?;
127 let (b_size, c_in, l_in) = self.dims3()?;
128 if c_in != c_in_k * groups {
129 Err(Error::Conv1dInvalidArgs {
130 inp_shape: self.shape().clone(),
131 k_shape: kernel.shape().clone(),
132 padding,
133 stride,
134 msg: "the number of in-channels on the input doesn't match the kernel size",
135 }
136 .bt())?
137 }
138
139 let params = ParamsConv1D {
140 b_size,
141 l_in,
142 c_out: c_out / groups,
143 c_in: c_in / groups,
144 k_size,
145 padding,
146 stride,
147 dilation,
148 };
149 if groups == 1 {
150 self.conv1d_single_group(kernel, ¶ms)
151 } else {
152 let blocks = self.chunk(groups, 1)?;
153 let kernel = kernel.chunk(groups, 0)?;
154 let blocks = blocks
155 .iter()
156 .zip(&kernel)
157 .map(|(block, kernel)| block.conv1d_single_group(kernel, ¶ms))
158 .collect::<Result<Vec<_>>>()?;
159 Tensor::cat(&blocks, 1)
160 }
161 }
162
163 fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
164 let storage =
165 self.storage()
166 .conv2d(self.layout(), &kernel.storage(), kernel.layout(), params)?;
167 let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D {
168 arg,
169 kernel,
170 padding: params.padding,
171 stride: params.stride,
172 dilation: params.dilation,
173 });
174 let out_dims = params.out_dims();
175 Ok(crate::tensor::from_storage(storage, out_dims, op, false))
176 }
177
178 pub fn conv2d(
180 &self,
181 kernel: &Self,
182 padding: usize,
183 stride: usize,
184 dilation: usize,
185 groups: usize,
186 ) -> Result<Self> {
187 let (b_size, c_in, i_h, i_w) = self.dims4()?;
188 let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
189 if c_in != c_in_k * groups {
190 crate::bail!(
191 "in_channel mismatch between input ({c_in}, groups {groups}) and kernel ({c_in_k})"
192 )
193 }
194 let params = ParamsConv2D {
195 b_size,
196 i_h,
197 i_w,
198 k_h,
199 k_w,
200 c_out: c_out / groups,
201 c_in: c_in / groups,
202 padding,
203 stride,
204 dilation,
205 cudnn_fwd_algo: None,
206 };
207 if groups == 1 {
208 self.conv2d_single_group(kernel, ¶ms)
209 } else {
210 let blocks = self.chunk(groups, 1)?;
211 let kernel = kernel.chunk(groups, 0)?;
212 let blocks = blocks
213 .iter()
214 .zip(&kernel)
215 .map(|(block, kernel)| block.conv2d_single_group(kernel, ¶ms))
216 .collect::<Result<Vec<_>>>()?;
217 Tensor::cat(&blocks, 1)
218 }
219 }
220
221 pub fn conv_transpose2d(
223 &self,
224 kernel: &Self,
225 padding: usize,
226 output_padding: usize,
227 stride: usize,
228 dilation: usize,
229 ) -> Result<Self> {
230 let (b_size, c_in, i_h, i_w) = self.dims4()?;
231 let (c_in_k, c_out, k_h, k_w) = kernel.dims4()?;
232 if c_in != c_in_k {
233 crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
234 }
235 let params = ParamsConvTranspose2D {
236 b_size,
237 i_h,
238 i_w,
239 k_h,
240 k_w,
241 c_out,
242 c_in,
243 padding,
244 output_padding,
245 stride,
246 dilation,
247 };
248 let storage = self.storage().conv_transpose2d(
249 self.layout(),
250 &kernel.storage(),
251 kernel.layout(),
252 ¶ms,
253 )?;
254 let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose2D {
255 arg,
256 kernel,
257 padding: params.padding,
258 output_padding: params.output_padding,
259 stride: params.stride,
260 dilation: params.dilation,
261 });
262 let out_dims = params.out_dims();
263 Ok(crate::tensor::from_storage(storage, out_dims, op, false))
264 }
265}