1use crate::{op::BackpropOp, op::Op, Error, Result, Tensor};
4
/// Parameters describing a 1d convolution applied to a single group.
///
/// For grouped convolutions, `c_out` and `c_in` hold the *per-group* channel
/// counts: the caller divides the full counts by `groups` before building this
/// struct (see `Tensor::conv1d_with_algo`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv1D {
    // Batch size.
    pub(crate) b_size: usize,
    // Length of the input along the convolved dimension.
    pub(crate) l_in: usize,
    // Number of output channels for this group.
    pub(crate) c_out: usize,
    // Number of input channels for this group.
    pub(crate) c_in: usize,
    // Kernel size along the convolved dimension.
    pub(crate) k_size: usize,
    pub(crate) padding: usize,
    pub(crate) stride: usize,
    pub(crate) dilation: usize,
    // Optional cuDNN forward-algorithm override; `None` lets the backend pick.
    pub(crate) cudnn_fwd_algo: Option<CudnnFwdAlgo>,
}
19
20impl ParamsConv1D {
21 pub(crate) fn l_out(&self) -> usize {
22 (self.l_in + 2 * self.padding - self.dilation * (self.k_size - 1) - 1) / self.stride + 1
23 }
24
25 pub(crate) fn out_dims(&self) -> Vec<usize> {
26 let l_out = self.l_out();
27 vec![self.b_size, self.c_out, l_out]
28 }
29}
30
/// Parameters describing a 1d transposed convolution applied to a single
/// group.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConvTranspose1D {
    // Batch size.
    pub(crate) b_size: usize,
    // Length of the input along the convolved dimension.
    pub(crate) l_in: usize,
    // Number of output channels.
    pub(crate) c_out: usize,
    // Number of input channels (per group).
    pub(crate) c_in: usize,
    // Kernel size along the convolved dimension.
    pub(crate) k_size: usize,
    pub(crate) padding: usize,
    pub(crate) output_padding: usize,
    pub(crate) stride: usize,
    pub(crate) dilation: usize,
}

impl ParamsConvTranspose1D {
    /// Output length of the transposed convolution:
    /// `(l_in - 1)*stride - 2*padding + dilation*(k_size - 1) + output_padding + 1`.
    ///
    /// The `2 * padding` subtraction is performed *last* so that valid
    /// configurations (e.g. `l_in = 1`, `stride = 1`, `k_size = 3`,
    /// `padding = 1`, which yields `l_out = 1`) do not underflow the
    /// intermediate `usize` value. This mirrors the evaluation order used by
    /// `ParamsConvTranspose2D::out_h`/`out_w`.
    pub(crate) fn l_out(&self) -> usize {
        (self.l_in - 1) * self.stride
            + self.dilation * (self.k_size - 1)
            + self.output_padding
            + 1
            - 2 * self.padding
    }

    /// Shape of the output tensor: `(batch, out_channels, l_out)`.
    pub(crate) fn out_dims(&self) -> Vec<usize> {
        vec![self.b_size, self.c_out, self.l_out()]
    }
}
57
/// Forward convolution algorithms that can be requested from cuDNN.
///
/// The variant names appear to mirror cuDNN's `cudnnConvolutionFwdAlgo_t`
/// entries — the actual backend mapping lives elsewhere; confirm against the
/// CUDA implementation before relying on a specific correspondence. Do not
/// reorder variants: the mapping may depend on their order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CudnnFwdAlgo {
    ImplicitGemm,
    ImplicitPrecompGemm,
    Gemm,
    Direct,
    Fft,
    FftTiling,
    Winograd,
    WinogradNonFused,
    Count,
}
70
/// Parameters describing a 2d convolution applied to a single group.
///
/// As in the 1d case, `c_out` and `c_in` are the *per-group* channel counts
/// (see `Tensor::conv2d_with_algo`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv2D {
    // Batch size.
    pub(crate) b_size: usize,
    // Input height.
    pub(crate) i_h: usize,
    // Input width.
    pub(crate) i_w: usize,
    // Kernel height.
    pub(crate) k_h: usize,
    // Kernel width.
    pub(crate) k_w: usize,
    // Per-group output channel count.
    pub(crate) c_out: usize,
    // Per-group input channel count.
    pub(crate) c_in: usize,
    pub(crate) padding: usize,
    pub(crate) stride: usize,
    pub(crate) dilation: usize,
    // Optional cuDNN forward-algorithm override; `None` lets the backend pick.
    // NOTE(review): this field is `pub` while the 1d equivalent is
    // `pub(crate)` — presumably intentional to allow external configuration,
    // but worth confirming.
    pub cudnn_fwd_algo: Option<CudnnFwdAlgo>,
}
85
86impl ParamsConv2D {
87 pub(crate) fn out_h(&self) -> usize {
88 (self.i_h + 2 * self.padding - self.dilation * (self.k_h - 1) - 1) / self.stride + 1
89 }
90
91 pub(crate) fn out_w(&self) -> usize {
92 (self.i_w + 2 * self.padding - self.dilation * (self.k_w - 1) - 1) / self.stride + 1
93 }
94
95 pub(crate) fn out_dims(&self) -> Vec<usize> {
96 vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
97 }
98}
99
/// Parameters describing a 2d transposed convolution.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConvTranspose2D {
    // Batch size.
    pub(crate) b_size: usize,
    // Input height.
    pub(crate) i_h: usize,
    // Input width.
    pub(crate) i_w: usize,
    // Kernel height.
    pub(crate) k_h: usize,
    // Kernel width.
    pub(crate) k_w: usize,
    // Number of output channels.
    pub(crate) c_out: usize,
    // Number of input channels.
    pub(crate) c_in: usize,
    pub(crate) padding: usize,
    pub(crate) output_padding: usize,
    pub(crate) stride: usize,
    pub(crate) dilation: usize,
}

impl ParamsConvTranspose2D {
    // Shared per-dimension formula:
    // `(i - 1)*stride + dilation*(k - 1) + output_padding + 1 - 2*padding`.
    fn out_dim(&self, i: usize, k: usize) -> usize {
        (i - 1) * self.stride + self.dilation * (k - 1) + self.output_padding + 1
            - 2 * self.padding
    }

    /// Output height of the transposed convolution.
    pub(crate) fn out_h(&self) -> usize {
        self.out_dim(self.i_h, self.k_h)
    }

    /// Output width of the transposed convolution.
    pub(crate) fn out_w(&self) -> usize {
        self.out_dim(self.i_w, self.k_w)
    }

    /// Shape of the output tensor: `(batch, out_channels, out_h, out_w)`.
    pub(crate) fn out_dims(&self) -> Vec<usize> {
        vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
    }
}
130
131impl Tensor {
132 fn conv1d_single_group(&self, kernel: &Self, params: &ParamsConv1D) -> Result<Self> {
133 let storage =
134 self.storage()
135 .conv1d(self.layout(), &kernel.storage(), kernel.layout(), params)?;
136 let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv1D {
137 arg,
138 kernel,
139 padding: params.padding,
140 stride: params.stride,
141 dilation: params.dilation,
142 });
143 let out_dims = params.out_dims();
144 Ok(crate::tensor::from_storage(storage, out_dims, op, false))
145 }
146
    /// Applies a 1d convolution over the input tensor, delegating to
    /// [`Self::conv1d_with_algo`] with no explicit cuDNN algorithm override.
    ///
    /// Expects `self` of shape `(b_size, c_in, l_in)` and `kernel` of shape
    /// `(c_out, c_in / groups, k_size)` — see the shape checks in
    /// `conv1d_with_algo`.
    pub fn conv1d(
        &self,
        kernel: &Self,
        padding: usize,
        stride: usize,
        dilation: usize,
        groups: usize,
    ) -> Result<Self> {
        self.conv1d_with_algo(kernel, padding, stride, dilation, groups, None)
    }
158
159 pub fn conv1d_with_algo(
161 &self,
162 kernel: &Self,
163 padding: usize,
164 stride: usize,
165 dilation: usize,
166 groups: usize,
167 cudnn_fwd_algo: Option<CudnnFwdAlgo>,
168 ) -> Result<Self> {
169 let (c_out, c_in_k, k_size) = kernel.dims3()?;
170 let (b_size, c_in, l_in) = self.dims3()?;
171 if c_in != c_in_k * groups {
172 Err(Error::Conv1dInvalidArgs {
173 inp_shape: self.shape().clone(),
174 k_shape: kernel.shape().clone(),
175 padding,
176 stride,
177 msg: "the number of in-channels on the input doesn't match the kernel size",
178 }
179 .bt())?
180 }
181
182 let params = ParamsConv1D {
183 b_size,
184 l_in,
185 c_out: c_out / groups,
186 c_in: c_in / groups,
187 k_size,
188 padding,
189 stride,
190 dilation,
191 cudnn_fwd_algo,
192 };
193 if groups == 1 {
194 self.conv1d_single_group(kernel, ¶ms)
195 } else {
196 let blocks = self.chunk(groups, 1)?;
197 let kernel = kernel.chunk(groups, 0)?;
198 let blocks = blocks
199 .iter()
200 .zip(&kernel)
201 .map(|(block, kernel)| block.conv1d_single_group(kernel, ¶ms))
202 .collect::<Result<Vec<_>>>()?;
203 Tensor::cat(&blocks, 1)
204 }
205 }
206
207 fn conv_transpose1d_single_group(
208 &self,
209 kernel: &Self,
210 params: &ParamsConvTranspose1D,
211 ) -> Result<Self> {
212 let storage = self.storage().conv_transpose1d(
213 self.layout(),
214 &kernel.storage(),
215 kernel.layout(),
216 params,
217 )?;
218 let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D {
219 arg,
220 kernel,
221 padding: params.padding,
222 output_padding: params.output_padding,
223 stride: params.stride,
224 dilation: params.dilation,
225 });
226 let out_dims = params.out_dims();
227 Ok(crate::tensor::from_storage(storage, out_dims, op, false))
228 }
229
230 pub fn conv_transpose1d(
232 &self,
233 kernel: &Self,
234 padding: usize,
235 output_padding: usize,
236 stride: usize,
237 dilation: usize,
238 groups: usize,
239 ) -> Result<Self> {
240 let (c_in_k, c_out, k_size) = kernel.dims3()?;
241 let (b_size, c_in, l_in) = self.dims3()?;
242 if c_in != c_in_k {
243 crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
244 }
245 if c_in % groups != 0 {
246 crate::bail!("in_channel {c_in} is not divisible by the number of groups")
247 }
248 let params = ParamsConvTranspose1D {
249 b_size,
250 l_in,
251 k_size,
252 c_out,
253 c_in: c_in / groups,
254 padding,
255 output_padding,
256 stride,
257 dilation,
258 };
259 if groups == 1 {
260 self.conv_transpose1d_single_group(kernel, ¶ms)
261 } else {
262 let blocks = self.chunk(groups, 1)?;
263 let kernel = kernel.chunk(groups, 0)?;
264 let blocks = blocks
265 .iter()
266 .zip(&kernel)
267 .map(|(block, kernel)| block.conv_transpose1d_single_group(kernel, ¶ms))
268 .collect::<Result<Vec<_>>>()?;
269 Tensor::cat(&blocks, 1)
270 }
271 }
272
273 fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
274 let storage =
275 self.storage()
276 .conv2d(self.layout(), &kernel.storage(), kernel.layout(), params)?;
277 let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D {
278 arg,
279 kernel,
280 padding: params.padding,
281 stride: params.stride,
282 dilation: params.dilation,
283 });
284 let out_dims = params.out_dims();
285 Ok(crate::tensor::from_storage(storage, out_dims, op, false))
286 }
287
    /// Applies a 2d convolution over the input tensor, delegating to
    /// [`Self::conv2d_with_algo`] with no explicit cuDNN algorithm override.
    ///
    /// Expects `self` of shape `(b_size, c_in, i_h, i_w)` and `kernel` of
    /// shape `(c_out, c_in / groups, k_h, k_w)` — see the shape checks in
    /// `conv2d_with_algo`.
    pub fn conv2d(
        &self,
        kernel: &Self,
        padding: usize,
        stride: usize,
        dilation: usize,
        groups: usize,
    ) -> Result<Self> {
        self.conv2d_with_algo(kernel, padding, stride, dilation, groups, None)
    }
299
300 pub fn conv2d_with_algo(
301 &self,
302 kernel: &Self,
303 padding: usize,
304 stride: usize,
305 dilation: usize,
306 groups: usize,
307 cudnn_fwd_algo: Option<CudnnFwdAlgo>,
308 ) -> Result<Self> {
309 let (b_size, c_in, i_h, i_w) = self.dims4()?;
310 let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
311 if c_in != c_in_k * groups {
312 crate::bail!(
313 "in_channel mismatch between input ({c_in}, groups {groups}) and kernel ({c_in_k})"
314 )
315 }
316 let params = ParamsConv2D {
317 b_size,
318 i_h,
319 i_w,
320 k_h,
321 k_w,
322 c_out: c_out / groups,
323 c_in: c_in / groups,
324 padding,
325 stride,
326 dilation,
327 cudnn_fwd_algo,
328 };
329 if groups == 1 {
330 self.conv2d_single_group(kernel, ¶ms)
331 } else {
332 let blocks = self.chunk(groups, 1)?;
333 let kernel = kernel.chunk(groups, 0)?;
334 let blocks = blocks
335 .iter()
336 .zip(&kernel)
337 .map(|(block, kernel)| block.conv2d_single_group(kernel, ¶ms))
338 .collect::<Result<Vec<_>>>()?;
339 Tensor::cat(&blocks, 1)
340 }
341 }
342
343 pub fn conv_transpose2d(
345 &self,
346 kernel: &Self,
347 padding: usize,
348 output_padding: usize,
349 stride: usize,
350 dilation: usize,
351 ) -> Result<Self> {
352 let (b_size, c_in, i_h, i_w) = self.dims4()?;
353 let (c_in_k, c_out, k_h, k_w) = kernel.dims4()?;
354 if c_in != c_in_k {
355 crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
356 }
357 let params = ParamsConvTranspose2D {
358 b_size,
359 i_h,
360 i_w,
361 k_h,
362 k_w,
363 c_out,
364 c_in,
365 padding,
366 output_padding,
367 stride,
368 dilation,
369 };
370 let storage = self.storage().conv_transpose2d(
371 self.layout(),
372 &kernel.storage(),
373 kernel.layout(),
374 ¶ms,
375 )?;
376 let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose2D {
377 arg,
378 kernel,
379 padding: params.padding,
380 output_padding: params.output_padding,
381 stride: params.stride,
382 dilation: params.dilation,
383 });
384 let out_dims = params.out_dims();
385 Ok(crate::tensor::from_storage(storage, out_dims, op, false))
386 }
387}