use crate::{op::BackpropOp, op::Op, Error, Result, Tensor};

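/// Parameters for a 1D convolution.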
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv1D {
    pub(crate) b_size: usize,
    pub(crate) l_in: usize,
    pub(crate) c_out: usize,
    pub(crate) c_in: usize,
    pub(crate) k_size: usize,
    pub(crate) padding: usize,
    pub(crate) stride: usize,
    pub(crate) dilation: usize,
}

impl ParamsConv1D {
    pub(crate) fn l_out(&self) -> usize {
        // Output length of a strided, dilated convolution.
        (self.l_in + 2 * self.padding - self.dilation * (self.k_size - 1) - 1) / self.stride + 1
    }

    pub(crate) fn out_dims(&self) -> Vec<usize> {
        let l_out = self.l_out();
        vec![self.b_size, self.c_out, l_out]
    }
}

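/// Parameters for a 1D transposed convolution.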
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConvTranspose1D {
    pub(crate) b_size: usize,
    pub(crate) l_in: usize,
    pub(crate) c_out: usize,
    pub(crate) c_in: usize,
    pub(crate) k_size: usize,
    pub(crate) padding: usize,
    pub(crate) output_padding: usize,
    pub(crate) stride: usize,
    pub(crate) dilation: usize,
}

impl ParamsConvTranspose1D {
    pub(crate) fn l_out(&self) -> usize {
        // Output length of a transposed convolution.
        (self.l_in - 1) * self.stride - 2 * self.padding
            + self.dilation * (self.k_size - 1)
            + self.output_padding
            + 1
    }

    pub(crate) fn out_dims(&self) -> Vec<usize> {
        let l_out = self.l_out();
        vec![self.b_size, self.c_out, l_out]
    }
}

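/// Forward algorithm to use for a cuDNN-backed 2D convolution; the variants
/// mirror cuDNN's `cudnnConvolutionFwdAlgo_t`.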
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CudnnFwdAlgo {
    ImplicitGemm,
    ImplicitPrecompGemm,
    Gemm,
    Direct,
    Fft,
    FftTiling,
    Winograd,
    WinogradNonFused,
    Count,
}

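/// Parameters for a 2D convolution.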
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv2D {
    pub(crate) b_size: usize,
    pub(crate) i_h: usize,
    pub(crate) i_w: usize,
    pub(crate) k_h: usize,
    pub(crate) k_w: usize,
    pub(crate) c_out: usize,
    pub(crate) c_in: usize,
    pub(crate) padding: usize,
    pub(crate) stride: usize,
    pub(crate) dilation: usize,
    pub cudnn_fwd_algo: Option<CudnnFwdAlgo>,
}

impl ParamsConv2D {
    pub(crate) fn out_h(&self) -> usize {
        (self.i_h + 2 * self.padding - self.dilation * (self.k_h - 1) - 1) / self.stride + 1
    }

    pub(crate) fn out_w(&self) -> usize {
        (self.i_w + 2 * self.padding - self.dilation * (self.k_w - 1) - 1) / self.stride + 1
    }

    pub(crate) fn out_dims(&self) -> Vec<usize> {
        vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
    }
}

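/// Parameters for a 2D transposed convolution.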
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConvTranspose2D {
    pub(crate) b_size: usize,
    pub(crate) i_h: usize,
    pub(crate) i_w: usize,
    pub(crate) k_h: usize,
    pub(crate) k_w: usize,
    pub(crate) c_out: usize,
    pub(crate) c_in: usize,
    pub(crate) padding: usize,
    pub(crate) output_padding: usize,
    pub(crate) stride: usize,
    pub(crate) dilation: usize,
}

impl ParamsConvTranspose2D {
    pub(crate) fn out_h(&self) -> usize {
        (self.i_h - 1) * self.stride + self.dilation * (self.k_h - 1) + self.output_padding + 1
            - 2 * self.padding
    }

    pub(crate) fn out_w(&self) -> usize {
        (self.i_w - 1) * self.stride + self.dilation * (self.k_w - 1) + self.output_padding + 1
            - 2 * self.padding
    }

    pub(crate) fn out_dims(&self) -> Vec<usize> {
        vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
    }
}

impl Tensor {
    fn conv1d_single_group(&self, kernel: &Self, params: &ParamsConv1D) -> Result<Self> {
        let storage =
            self.storage()
                .conv1d(self.layout(), &kernel.storage(), kernel.layout(), params)?;
        let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv1D {
            arg,
            kernel,
            padding: params.padding,
            stride: params.stride,
            dilation: params.dilation,
        });
        let out_dims = params.out_dims();
        Ok(crate::tensor::from_storage(storage, out_dims, op, false))
    }

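    /// Applies a 1D convolution over the input tensor.
    ///
    /// The input is expected to have shape `(b_size, c_in, l_in)` and the kernel
    /// `(c_out, c_in / groups, k_size)`; the result has shape `(b_size, c_out, l_out)`.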
    pub fn conv1d(
        &self,
        kernel: &Self,
        padding: usize,
        stride: usize,
        dilation: usize,
        groups: usize,
    ) -> Result<Self> {
        let (c_out, c_in_k, k_size) = kernel.dims3()?;
        let (b_size, c_in, l_in) = self.dims3()?;
        if c_in != c_in_k * groups {
            Err(Error::Conv1dInvalidArgs {
                inp_shape: self.shape().clone(),
                k_shape: kernel.shape().clone(),
                padding,
                stride,
                msg: "the number of in-channels on the input doesn't match the kernel size",
            }
            .bt())?
        }

        let params = ParamsConv1D {
            b_size,
            l_in,
            c_out: c_out / groups,
            c_in: c_in / groups,
            k_size,
            padding,
            stride,
            dilation,
        };
        if groups == 1 {
            self.conv1d_single_group(kernel, &params)
        } else {
            // Grouped convolution: split the input and the kernel into `groups` chunks along
            // their channel dimensions, convolve each pair separately, and concatenate the
            // per-group results along the channel dimension.
            let blocks = self.chunk(groups, 1)?;
            let kernel = kernel.chunk(groups, 0)?;
            let blocks = blocks
                .iter()
                .zip(&kernel)
                .map(|(block, kernel)| block.conv1d_single_group(kernel, &params))
                .collect::<Result<Vec<_>>>()?;
            Tensor::cat(&blocks, 1)
        }
    }

    fn conv_transpose1d_single_group(
        &self,
        kernel: &Self,
        params: &ParamsConvTranspose1D,
    ) -> Result<Self> {
        let storage = self.storage().conv_transpose1d(
            self.layout(),
            &kernel.storage(),
            kernel.layout(),
            params,
        )?;
        let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D {
            arg,
            kernel,
            padding: params.padding,
            output_padding: params.output_padding,
            stride: params.stride,
            dilation: params.dilation,
        });
        let out_dims = params.out_dims();
        Ok(crate::tensor::from_storage(storage, out_dims, op, false))
    }

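    /// Applies a 1D transposed convolution over the input tensor.
    ///
    /// The input is expected to have shape `(b_size, c_in, l_in)` and the kernel
    /// `(c_in, c_out, k_size)`.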
    pub fn conv_transpose1d(
        &self,
        kernel: &Self,
        padding: usize,
        output_padding: usize,
        stride: usize,
        dilation: usize,
        groups: usize,
    ) -> Result<Self> {
        let (c_in_k, c_out, k_size) = kernel.dims3()?;
        let (b_size, c_in, l_in) = self.dims3()?;
        if c_in != c_in_k {
            crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
        }
        if c_in % groups != 0 {
            crate::bail!("in_channel {c_in} is not divisible by the number of groups")
        }
        let params = ParamsConvTranspose1D {
            b_size,
            l_in,
            k_size,
            c_out,
            c_in: c_in / groups,
            padding,
            output_padding,
            stride,
            dilation,
        };
        if groups == 1 {
            self.conv_transpose1d_single_group(kernel, &params)
        } else {
            let blocks = self.chunk(groups, 1)?;
            let kernel = kernel.chunk(groups, 0)?;
            let blocks = blocks
                .iter()
                .zip(&kernel)
                .map(|(block, kernel)| block.conv_transpose1d_single_group(kernel, &params))
                .collect::<Result<Vec<_>>>()?;
            Tensor::cat(&blocks, 1)
        }
    }

    fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
        let storage =
            self.storage()
                .conv2d(self.layout(), &kernel.storage(), kernel.layout(), params)?;
        let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D {
            arg,
            kernel,
            padding: params.padding,
            stride: params.stride,
            dilation: params.dilation,
        });
        let out_dims = params.out_dims();
        Ok(crate::tensor::from_storage(storage, out_dims, op, false))
    }

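    /// Applies a 2D convolution over the input tensor.
    ///
    /// The input is expected to have shape `(b_size, c_in, i_h, i_w)` and the kernel
    /// `(c_out, c_in / groups, k_h, k_w)`; the result has shape `(b_size, c_out, out_h, out_w)`.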
    pub fn conv2d(
        &self,
        kernel: &Self,
        padding: usize,
        stride: usize,
        dilation: usize,
        groups: usize,
    ) -> Result<Self> {
        let (b_size, c_in, i_h, i_w) = self.dims4()?;
        let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
        if c_in != c_in_k * groups {
            crate::bail!(
                "in_channel mismatch between input ({c_in}, groups {groups}) and kernel ({c_in_k})"
            )
        }
        let params = ParamsConv2D {
            b_size,
            i_h,
            i_w,
            k_h,
            k_w,
            c_out: c_out / groups,
            c_in: c_in / groups,
            padding,
            stride,
            dilation,
            cudnn_fwd_algo: None,
        };
        if groups == 1 {
            self.conv2d_single_group(kernel, &params)
        } else {
            let blocks = self.chunk(groups, 1)?;
            let kernel = kernel.chunk(groups, 0)?;
            let blocks = blocks
                .iter()
                .zip(&kernel)
                .map(|(block, kernel)| block.conv2d_single_group(kernel, &params))
                .collect::<Result<Vec<_>>>()?;
            Tensor::cat(&blocks, 1)
        }
    }

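    /// Applies a 2D transposed convolution over the input tensor.
    ///
    /// The input is expected to have shape `(b_size, c_in, i_h, i_w)` and the kernel
    /// `(c_in, c_out, k_h, k_w)`.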
    pub fn conv_transpose2d(
        &self,
        kernel: &Self,
        padding: usize,
        output_padding: usize,
        stride: usize,
        dilation: usize,
    ) -> Result<Self> {
        let (b_size, c_in, i_h, i_w) = self.dims4()?;
        let (c_in_k, c_out, k_h, k_w) = kernel.dims4()?;
        if c_in != c_in_k {
            crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
        }
        let params = ParamsConvTranspose2D {
            b_size,
            i_h,
            i_w,
            k_h,
            k_w,
            c_out,
            c_in,
            padding,
            output_padding,
            stride,
            dilation,
        };
        let storage = self.storage().conv_transpose2d(
            self.layout(),
            &kernel.storage(),
            kernel.layout(),
            &params,
        )?;
        let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose2D {
            arg,
            kernel,
            padding: params.padding,
            output_padding: params.output_padding,
            stride: params.stride,
            dilation: params.dilation,
        });
        let out_dims = params.out_dims();
        Ok(crate::tensor::from_storage(storage, out_dims, op, false))
    }
}