candle_core/
conv.rs

1//! 1D and 2D Convolutions
2//!
3use crate::{op::BackpropOp, op::Op, Error, Result, Tensor};
4
/// Hyper-parameters of a 1D convolution: kernel geometry plus the input
/// geometry needed to derive the output length.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv1D {
    /// Batch size of the input tensor.
    pub(crate) b_size: usize,
    // Maybe we should have a version without l_in as this bit depends on the input and not only on
    // the weights.
    /// Length of the input along its last dimension.
    pub(crate) l_in: usize,
    /// Number of output channels (per group when grouped convolution is used,
    /// see `Tensor::conv1d` which divides by `groups` before building this).
    pub(crate) c_out: usize,
    /// Number of input channels (per group, same caveat as `c_out`).
    pub(crate) c_in: usize,
    /// Kernel size along the length dimension.
    pub(crate) k_size: usize,
    /// Zero padding added on both sides of the input.
    pub(crate) padding: usize,
    /// Stride between consecutive kernel applications.
    pub(crate) stride: usize,
    /// Spacing between kernel elements.
    pub(crate) dilation: usize,
    /// Optional forward-algorithm hint, presumably consumed by a cuDNN
    /// backend elsewhere in the crate — not used in this file.
    pub(crate) cudnn_fwd_algo: Option<CudnnFwdAlgo>,
}
19
20impl ParamsConv1D {
21    pub(crate) fn l_out(&self) -> usize {
22        (self.l_in + 2 * self.padding - self.dilation * (self.k_size - 1) - 1) / self.stride + 1
23    }
24
25    pub(crate) fn out_dims(&self) -> Vec<usize> {
26        let l_out = self.l_out();
27        vec![self.b_size, self.c_out, l_out]
28    }
29}
30
/// Hyper-parameters of a 1D transposed convolution.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConvTranspose1D {
    /// Batch size of the input tensor.
    pub(crate) b_size: usize,
    /// Length of the input along its last dimension.
    pub(crate) l_in: usize,
    /// Number of output channels.
    pub(crate) c_out: usize,
    /// Number of input channels (per group).
    pub(crate) c_in: usize,
    /// Kernel size along the length dimension.
    pub(crate) k_size: usize,
    /// Padding subtracted from the output size.
    pub(crate) padding: usize,
    /// Extra size added to one side of the output.
    pub(crate) output_padding: usize,
    /// Stride of the matching forward convolution.
    pub(crate) stride: usize,
    /// Spacing between kernel elements.
    pub(crate) dilation: usize,
}

impl ParamsConvTranspose1D {
    /// Output length of the transposed convolution:
    /// `(l_in - 1)*stride - 2*padding + dilation*(k_size - 1) + output_padding + 1`.
    pub(crate) fn l_out(&self) -> usize {
        // Keep the original evaluation order (subtraction first) so that the
        // usize overflow behavior on degenerate inputs is unchanged.
        let base = (self.l_in - 1) * self.stride - 2 * self.padding;
        base + self.dilation * (self.k_size - 1) + self.output_padding + 1
    }

    /// Output shape as `[batch, channels, length]`.
    pub(crate) fn out_dims(&self) -> Vec<usize> {
        vec![self.b_size, self.c_out, self.l_out()]
    }
}
57
/// Selectable forward-convolution algorithms.
/// NOTE(review): variant names mirror cuDNN's `cudnnConvolutionFwdAlgo_t`
/// enumerants; the actual mapping happens in backend code outside this file —
/// confirm there before relying on specific semantics.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CudnnFwdAlgo {
    ImplicitGemm,
    ImplicitPrecompGemm,
    Gemm,
    Direct,
    Fft,
    FftTiling,
    Winograd,
    WinogradNonFused,
    Count,
}
70
/// Hyper-parameters of a 2D convolution: kernel geometry plus the input
/// geometry needed to derive the output height/width.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv2D {
    /// Batch size of the input tensor.
    pub(crate) b_size: usize,
    /// Input height.
    pub(crate) i_h: usize,
    /// Input width.
    pub(crate) i_w: usize,
    /// Kernel height.
    pub(crate) k_h: usize,
    /// Kernel width.
    pub(crate) k_w: usize,
    /// Number of output channels (per group, see `Tensor::conv2d`).
    pub(crate) c_out: usize,
    /// Number of input channels (per group).
    pub(crate) c_in: usize,
    /// Zero padding applied on all four sides.
    pub(crate) padding: usize,
    /// Stride, shared by both spatial dimensions.
    pub(crate) stride: usize,
    /// Dilation, shared by both spatial dimensions.
    pub(crate) dilation: usize,
    /// Optional forward-algorithm hint for a cuDNN backend.
    /// NOTE(review): this field is `pub` while the 1D counterpart is
    /// `pub(crate)` — confirm whether the wider visibility is intentional.
    pub cudnn_fwd_algo: Option<CudnnFwdAlgo>,
}
85
86impl ParamsConv2D {
87    pub(crate) fn out_h(&self) -> usize {
88        (self.i_h + 2 * self.padding - self.dilation * (self.k_h - 1) - 1) / self.stride + 1
89    }
90
91    pub(crate) fn out_w(&self) -> usize {
92        (self.i_w + 2 * self.padding - self.dilation * (self.k_w - 1) - 1) / self.stride + 1
93    }
94
95    pub(crate) fn out_dims(&self) -> Vec<usize> {
96        vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
97    }
98}
99
/// Hyper-parameters of a 2D transposed convolution.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConvTranspose2D {
    /// Batch size of the input tensor.
    pub(crate) b_size: usize,
    /// Input height.
    pub(crate) i_h: usize,
    /// Input width.
    pub(crate) i_w: usize,
    /// Kernel height.
    pub(crate) k_h: usize,
    /// Kernel width.
    pub(crate) k_w: usize,
    /// Number of output channels.
    pub(crate) c_out: usize,
    /// Number of input channels.
    pub(crate) c_in: usize,
    /// Padding subtracted from the output size.
    pub(crate) padding: usize,
    /// Extra size added to one side of the output.
    pub(crate) output_padding: usize,
    /// Stride of the matching forward convolution.
    pub(crate) stride: usize,
    /// Dilation, shared by both spatial dimensions.
    pub(crate) dilation: usize,
}

impl ParamsConvTranspose2D {
    /// Output size along one spatial axis:
    /// `(i - 1)*stride + dilation*(k - 1) + output_padding + 1 - 2*padding`.
    /// The subtraction is kept last, matching the original evaluation order,
    /// so usize overflow behavior on degenerate inputs is unchanged.
    fn out_dim(&self, in_size: usize, k_size: usize) -> usize {
        (in_size - 1) * self.stride + self.dilation * (k_size - 1) + self.output_padding + 1
            - 2 * self.padding
    }

    /// Output height of the transposed convolution.
    pub(crate) fn out_h(&self) -> usize {
        self.out_dim(self.i_h, self.k_h)
    }

    /// Output width of the transposed convolution.
    pub(crate) fn out_w(&self) -> usize {
        self.out_dim(self.i_w, self.k_w)
    }

    /// Output shape as `[batch, channels, height, width]`.
    pub(crate) fn out_dims(&self) -> Vec<usize> {
        vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
    }
}
130
impl Tensor {
    /// Runs a 1D convolution for a single group: the computation is delegated
    /// to the backend storage, and the op is recorded so gradients can flow
    /// back to both the input and the kernel.
    fn conv1d_single_group(&self, kernel: &Self, params: &ParamsConv1D) -> Result<Self> {
        let storage =
            self.storage()
                .conv1d(self.layout(), &kernel.storage(), kernel.layout(), params)?;
        let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv1D {
            arg,
            kernel,
            padding: params.padding,
            stride: params.stride,
            dilation: params.dilation,
        });
        let out_dims = params.out_dims();
        Ok(crate::tensor::from_storage(storage, out_dims, op, false))
    }

    /// Applies a 1D convolution over the input tensor.
    ///
    /// `self` is expected to have shape `(b_size, c_in, l_in)` and `kernel`
    /// shape `(c_out, c_in / groups, k_size)`; see `conv1d_with_algo` for the
    /// actual implementation (this is the same call with no algorithm hint).
    pub fn conv1d(
        &self,
        kernel: &Self,
        padding: usize,
        stride: usize,
        dilation: usize,
        groups: usize,
    ) -> Result<Self> {
        self.conv1d_with_algo(kernel, padding, stride, dilation, groups, None)
    }

    /// Applies a 1D convolution over the input tensor.
    ///
    /// Like `conv1d`, but also takes an optional `cudnn_fwd_algo` hint which
    /// is stored in the params for the backend to pick up.
    ///
    /// # Errors
    /// Returns `Error::Conv1dInvalidArgs` when the input channel count does
    /// not equal `kernel_in_channels * groups`.
    pub fn conv1d_with_algo(
        &self,
        kernel: &Self,
        padding: usize,
        stride: usize,
        dilation: usize,
        groups: usize,
        cudnn_fwd_algo: Option<CudnnFwdAlgo>,
    ) -> Result<Self> {
        let (c_out, c_in_k, k_size) = kernel.dims3()?;
        let (b_size, c_in, l_in) = self.dims3()?;
        if c_in != c_in_k * groups {
            Err(Error::Conv1dInvalidArgs {
                inp_shape: self.shape().clone(),
                k_shape: kernel.shape().clone(),
                padding,
                stride,
                msg: "the number of in-channels on the input doesn't match the kernel size",
            }
            .bt())?
        }

        // Params describe a single group, hence the division by `groups`.
        let params = ParamsConv1D {
            b_size,
            l_in,
            c_out: c_out / groups,
            c_in: c_in / groups,
            k_size,
            padding,
            stride,
            dilation,
            cudnn_fwd_algo,
        };
        if groups == 1 {
            self.conv1d_single_group(kernel, &params)
        } else {
            // Grouped convolution: split the input channels (dim 1) and the
            // kernel output channels (dim 0) into `groups` chunks, convolve
            // each pair independently, then concatenate along channels.
            let blocks = self.chunk(groups, 1)?;
            let kernel = kernel.chunk(groups, 0)?;
            let blocks = blocks
                .iter()
                .zip(&kernel)
                .map(|(block, kernel)| block.conv1d_single_group(kernel, &params))
                .collect::<Result<Vec<_>>>()?;
            Tensor::cat(&blocks, 1)
        }
    }

    /// Runs a 1D transposed convolution for a single group, delegating to the
    /// backend storage and recording the op for backprop.
    fn conv_transpose1d_single_group(
        &self,
        kernel: &Self,
        params: &ParamsConvTranspose1D,
    ) -> Result<Self> {
        let storage = self.storage().conv_transpose1d(
            self.layout(),
            &kernel.storage(),
            kernel.layout(),
            params,
        )?;
        let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D {
            arg,
            kernel,
            padding: params.padding,
            output_padding: params.output_padding,
            stride: params.stride,
            dilation: params.dilation,
        });
        let out_dims = params.out_dims();
        Ok(crate::tensor::from_storage(storage, out_dims, op, false))
    }

    /// Applies a 1D transposed convolution over the input tensor.
    ///
    /// `self` is expected to have shape `(b_size, c_in, l_in)` and `kernel`
    /// shape `(c_in, c_out_per_group, k_size)` — note the kernel's in-channel
    /// dimension comes first, unlike `conv1d`.
    ///
    /// # Errors
    /// Fails when the input/kernel in-channel counts differ, or when `c_in`
    /// is not divisible by `groups`.
    pub fn conv_transpose1d(
        &self,
        kernel: &Self,
        padding: usize,
        output_padding: usize,
        stride: usize,
        dilation: usize,
        groups: usize,
    ) -> Result<Self> {
        let (c_in_k, c_out, k_size) = kernel.dims3()?;
        let (b_size, c_in, l_in) = self.dims3()?;
        if c_in != c_in_k {
            crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
        }
        if c_in % groups != 0 {
            crate::bail!("in_channel {c_in} is not divisible by the number of groups")
        }
        let params = ParamsConvTranspose1D {
            b_size,
            l_in,
            k_size,
            // `c_out` here is the kernel's second dim, i.e. the per-group
            // output channel count; only `c_in` is divided by `groups`.
            c_out,
            c_in: c_in / groups,
            padding,
            output_padding,
            stride,
            dilation,
        };
        if groups == 1 {
            self.conv_transpose1d_single_group(kernel, &params)
        } else {
            // Grouped case: split input channels (dim 1) and the kernel's
            // in-channel dimension (dim 0), then concatenate the per-group
            // outputs along channels.
            let blocks = self.chunk(groups, 1)?;
            let kernel = kernel.chunk(groups, 0)?;
            let blocks = blocks
                .iter()
                .zip(&kernel)
                .map(|(block, kernel)| block.conv_transpose1d_single_group(kernel, &params))
                .collect::<Result<Vec<_>>>()?;
            Tensor::cat(&blocks, 1)
        }
    }

    /// Runs a 2D convolution for a single group, delegating to the backend
    /// storage and recording the op for backprop.
    fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
        let storage =
            self.storage()
                .conv2d(self.layout(), &kernel.storage(), kernel.layout(), params)?;
        let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D {
            arg,
            kernel,
            padding: params.padding,
            stride: params.stride,
            dilation: params.dilation,
        });
        let out_dims = params.out_dims();
        Ok(crate::tensor::from_storage(storage, out_dims, op, false))
    }

    /// Applies a 2D convolution over the input tensor.
    ///
    /// `self` is expected to have shape `(b_size, c_in, i_h, i_w)` and
    /// `kernel` shape `(c_out, c_in / groups, k_h, k_w)`. Same as
    /// `conv2d_with_algo` with no algorithm hint.
    pub fn conv2d(
        &self,
        kernel: &Self,
        padding: usize,
        stride: usize,
        dilation: usize,
        groups: usize,
    ) -> Result<Self> {
        self.conv2d_with_algo(kernel, padding, stride, dilation, groups, None)
    }

    /// Applies a 2D convolution over the input tensor, with an optional
    /// cuDNN forward-algorithm hint stored in the params for the backend.
    ///
    /// # Errors
    /// Fails when the input channel count does not equal
    /// `kernel_in_channels * groups`.
    pub fn conv2d_with_algo(
        &self,
        kernel: &Self,
        padding: usize,
        stride: usize,
        dilation: usize,
        groups: usize,
        cudnn_fwd_algo: Option<CudnnFwdAlgo>,
    ) -> Result<Self> {
        let (b_size, c_in, i_h, i_w) = self.dims4()?;
        let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
        if c_in != c_in_k * groups {
            crate::bail!(
                "in_channel mismatch between input ({c_in}, groups {groups}) and kernel ({c_in_k})"
            )
        }
        // Params describe a single group, hence the division by `groups`.
        let params = ParamsConv2D {
            b_size,
            i_h,
            i_w,
            k_h,
            k_w,
            c_out: c_out / groups,
            c_in: c_in / groups,
            padding,
            stride,
            dilation,
            cudnn_fwd_algo,
        };
        if groups == 1 {
            self.conv2d_single_group(kernel, &params)
        } else {
            // Grouped convolution: split input channels (dim 1) and kernel
            // output channels (dim 0), convolve pairwise, concatenate along
            // channels.
            let blocks = self.chunk(groups, 1)?;
            let kernel = kernel.chunk(groups, 0)?;
            let blocks = blocks
                .iter()
                .zip(&kernel)
                .map(|(block, kernel)| block.conv2d_single_group(kernel, &params))
                .collect::<Result<Vec<_>>>()?;
            Tensor::cat(&blocks, 1)
        }
    }

    /// Applies a 2D transposed convolution over the input tensor.
    ///
    /// `self` is expected to have shape `(b_size, c_in, i_h, i_w)` and
    /// `kernel` shape `(c_in, c_out, k_h, k_w)` — in-channels first, unlike
    /// `conv2d`. Note that, unlike the 1D variant, no `groups` parameter is
    /// supported here.
    ///
    /// # Errors
    /// Fails when the input/kernel in-channel counts differ.
    pub fn conv_transpose2d(
        &self,
        kernel: &Self,
        padding: usize,
        output_padding: usize,
        stride: usize,
        dilation: usize,
    ) -> Result<Self> {
        let (b_size, c_in, i_h, i_w) = self.dims4()?;
        let (c_in_k, c_out, k_h, k_w) = kernel.dims4()?;
        if c_in != c_in_k {
            crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
        }
        let params = ParamsConvTranspose2D {
            b_size,
            i_h,
            i_w,
            k_h,
            k_w,
            c_out,
            c_in,
            padding,
            output_padding,
            stride,
            dilation,
        };
        let storage = self.storage().conv_transpose2d(
            self.layout(),
            &kernel.storage(),
            kernel.layout(),
            &params,
        )?;
        // Record the op so gradients flow to both input and kernel.
        let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose2D {
            arg,
            kernel,
            padding: params.padding,
            output_padding: params.output_padding,
            stride: params.stride,
            dilation: params.dilation,
        });
        let out_dims = params.out_dims();
        Ok(crate::tensor::from_storage(storage, out_dims, op, false))
    }
}
387}