candle_core/
conv.rs

1//! 1D and 2D Convolutions
2//!
3use crate::{op::BackpropOp, op::Op, Error, Result, Tensor};
4
/// Parameters describing a 1D convolution: batch/input geometry, kernel
/// geometry, and the padding/stride/dilation hyper-parameters.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv1D {
    /// Batch size.
    pub(crate) b_size: usize,
    // Maybe we should have a version without l_in as this bit depends on the input and not only on
    // the weights.
    /// Length of the input along its last dimension.
    pub(crate) l_in: usize,
    /// Number of output channels (per group when used with grouped conv).
    pub(crate) c_out: usize,
    /// Number of input channels (per group when used with grouped conv).
    pub(crate) c_in: usize,
    /// Kernel length.
    pub(crate) k_size: usize,
    /// Zero-padding applied on both sides of the input.
    pub(crate) padding: usize,
    /// Step between consecutive kernel applications.
    pub(crate) stride: usize,
    /// Spacing between kernel elements.
    pub(crate) dilation: usize,
}
18
19impl ParamsConv1D {
20    pub(crate) fn l_out(&self) -> usize {
21        (self.l_in + 2 * self.padding - self.dilation * (self.k_size - 1) - 1) / self.stride + 1
22    }
23
24    pub(crate) fn out_dims(&self) -> Vec<usize> {
25        let l_out = self.l_out();
26        vec![self.b_size, self.c_out, l_out]
27    }
28}
29
/// Parameters describing a 1D transposed convolution.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConvTranspose1D {
    /// Batch size.
    pub(crate) b_size: usize,
    /// Length of the input along its last dimension.
    pub(crate) l_in: usize,
    /// Number of output channels.
    pub(crate) c_out: usize,
    /// Number of input channels (per group when used with grouped conv).
    pub(crate) c_in: usize,
    /// Kernel length.
    pub(crate) k_size: usize,
    /// Zero-padding that the corresponding forward convolution would apply.
    pub(crate) padding: usize,
    /// Extra size added to one side of the output.
    pub(crate) output_padding: usize,
    /// Stride of the corresponding forward convolution.
    pub(crate) stride: usize,
    /// Spacing between kernel elements.
    pub(crate) dilation: usize,
}
42
43impl ParamsConvTranspose1D {
44    pub(crate) fn l_out(&self) -> usize {
45        (self.l_in - 1) * self.stride - 2 * self.padding
46            + self.dilation * (self.k_size - 1)
47            + self.output_padding
48            + 1
49    }
50
51    pub(crate) fn out_dims(&self) -> Vec<usize> {
52        let l_out = self.l_out();
53        vec![self.b_size, self.c_out, l_out]
54    }
55}
56
/// Forward-convolution algorithm selector.
/// NOTE(review): the variant names track cuDNN's forward-algorithm enum
/// (`cudnnConvolutionFwdAlgo_t`) — confirm the mapping in the cudnn backend.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CudnnFwdAlgo {
    ImplicitGemm,
    ImplicitPrecompGemm,
    Gemm,
    Direct,
    Fft,
    FftTiling,
    Winograd,
    WinogradNonFused,
    Count,
}
69
/// Parameters describing a 2D convolution. Padding, stride, and dilation are
/// shared between the height and width dimensions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv2D {
    /// Batch size.
    pub(crate) b_size: usize,
    /// Input height.
    pub(crate) i_h: usize,
    /// Input width.
    pub(crate) i_w: usize,
    /// Kernel height.
    pub(crate) k_h: usize,
    /// Kernel width.
    pub(crate) k_w: usize,
    /// Number of output channels (per group when used with grouped conv).
    pub(crate) c_out: usize,
    /// Number of input channels (per group when used with grouped conv).
    pub(crate) c_in: usize,
    /// Zero-padding applied on all four sides of the input.
    pub(crate) padding: usize,
    /// Step between consecutive kernel applications.
    pub(crate) stride: usize,
    /// Spacing between kernel elements.
    pub(crate) dilation: usize,
    /// Optional algorithm hint for the cudnn backend; `None` lets the backend choose.
    pub cudnn_fwd_algo: Option<CudnnFwdAlgo>,
}
84
85impl ParamsConv2D {
86    pub(crate) fn out_h(&self) -> usize {
87        (self.i_h + 2 * self.padding - self.dilation * (self.k_h - 1) - 1) / self.stride + 1
88    }
89
90    pub(crate) fn out_w(&self) -> usize {
91        (self.i_w + 2 * self.padding - self.dilation * (self.k_w - 1) - 1) / self.stride + 1
92    }
93
94    pub(crate) fn out_dims(&self) -> Vec<usize> {
95        vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
96    }
97}
98
/// Parameters describing a 2D transposed convolution. Padding, stride, and
/// dilation are shared between the height and width dimensions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConvTranspose2D {
    /// Batch size.
    pub(crate) b_size: usize,
    /// Input height.
    pub(crate) i_h: usize,
    /// Input width.
    pub(crate) i_w: usize,
    /// Kernel height.
    pub(crate) k_h: usize,
    /// Kernel width.
    pub(crate) k_w: usize,
    /// Number of output channels.
    pub(crate) c_out: usize,
    /// Number of input channels.
    pub(crate) c_in: usize,
    /// Zero-padding that the corresponding forward convolution would apply.
    pub(crate) padding: usize,
    /// Extra size added to one side of each output dimension.
    pub(crate) output_padding: usize,
    /// Stride of the corresponding forward convolution.
    pub(crate) stride: usize,
    /// Spacing between kernel elements.
    pub(crate) dilation: usize,
}
113
114impl ParamsConvTranspose2D {
115    pub(crate) fn out_h(&self) -> usize {
116        (self.i_h - 1) * self.stride + self.dilation * (self.k_h - 1) + self.output_padding + 1
117            - 2 * self.padding
118    }
119
120    pub(crate) fn out_w(&self) -> usize {
121        (self.i_w - 1) * self.stride + self.dilation * (self.k_w - 1) + self.output_padding + 1
122            - 2 * self.padding
123    }
124
125    pub(crate) fn out_dims(&self) -> Vec<usize> {
126        vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
127    }
128}
129
impl Tensor {
    /// Runs a 1D convolution for a single group.
    ///
    /// `self` holds the input and `kernel` the weights; the shapes and
    /// hyper-parameters are carried by `params`. The backward op records both
    /// operands so gradients can flow to the input and the kernel.
    fn conv1d_single_group(&self, kernel: &Self, params: &ParamsConv1D) -> Result<Self> {
        let storage =
            self.storage()
                .conv1d(self.layout(), &kernel.storage(), kernel.layout(), params)?;
        let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv1D {
            arg,
            kernel,
            padding: params.padding,
            stride: params.stride,
            dilation: params.dilation,
        });
        let out_dims = params.out_dims();
        // NOTE(review): the trailing `false` presumably marks the result as
        // not-a-variable — confirm against `from_storage`'s signature.
        Ok(crate::tensor::from_storage(storage, out_dims, op, false))
    }

    /// Applies a 1D convolution over the input tensor.
    ///
    /// - `self`: input of shape `(b_size, c_in, l_in)`.
    /// - `kernel`: weights of shape `(c_out, c_in / groups, k_size)`.
    ///
    /// Returns a tensor of shape `(b_size, c_out, l_out)`. Errors if the
    /// input's channel count does not equal `c_in_k * groups`.
    pub fn conv1d(
        &self,
        kernel: &Self,
        padding: usize,
        stride: usize,
        dilation: usize,
        groups: usize,
    ) -> Result<Self> {
        let (c_out, c_in_k, k_size) = kernel.dims3()?;
        let (b_size, c_in, l_in) = self.dims3()?;
        if c_in != c_in_k * groups {
            Err(Error::Conv1dInvalidArgs {
                inp_shape: self.shape().clone(),
                k_shape: kernel.shape().clone(),
                padding,
                stride,
                msg: "the number of in-channels on the input doesn't match the kernel size",
            }
            // `.bt()` attaches a backtrace to the error.
            .bt())?
        }

        // Per-group channel counts: each group sees c_in / groups input
        // channels and produces c_out / groups output channels.
        let params = ParamsConv1D {
            b_size,
            l_in,
            c_out: c_out / groups,
            c_in: c_in / groups,
            k_size,
            padding,
            stride,
            dilation,
        };
        if groups == 1 {
            self.conv1d_single_group(kernel, &params)
        } else {
            // Grouped conv: split the input channels (dim 1) and the kernel's
            // out-channels (dim 0), convolve each pair, then concatenate the
            // per-group outputs back along the channel dimension.
            let blocks = self.chunk(groups, 1)?;
            let kernel = kernel.chunk(groups, 0)?;
            let blocks = blocks
                .iter()
                .zip(&kernel)
                .map(|(block, kernel)| block.conv1d_single_group(kernel, &params))
                .collect::<Result<Vec<_>>>()?;
            Tensor::cat(&blocks, 1)
        }
    }

    /// Runs a 1D transposed convolution for a single group; see
    /// `conv1d_single_group` for the overall structure.
    fn conv_transpose1d_single_group(
        &self,
        kernel: &Self,
        params: &ParamsConvTranspose1D,
    ) -> Result<Self> {
        let storage = self.storage().conv_transpose1d(
            self.layout(),
            &kernel.storage(),
            kernel.layout(),
            params,
        )?;
        let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D {
            arg,
            kernel,
            padding: params.padding,
            output_padding: params.output_padding,
            stride: params.stride,
            dilation: params.dilation,
        });
        let out_dims = params.out_dims();
        Ok(crate::tensor::from_storage(storage, out_dims, op, false))
    }

    /// Applies a 1D transposed convolution over the input tensor.
    ///
    /// - `self`: input of shape `(b_size, c_in, l_in)`.
    /// - `kernel`: weights of shape `(c_in, c_out, k_size)` — note the
    ///   in-channels come first, unlike `conv1d`.
    ///
    /// With `groups > 1` each group contributes `c_out` channels, so the
    /// concatenated result has `groups * c_out` output channels.
    pub fn conv_transpose1d(
        &self,
        kernel: &Self,
        padding: usize,
        output_padding: usize,
        stride: usize,
        dilation: usize,
        groups: usize,
    ) -> Result<Self> {
        let (c_in_k, c_out, k_size) = kernel.dims3()?;
        let (b_size, c_in, l_in) = self.dims3()?;
        if c_in != c_in_k {
            crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
        }
        if c_in % groups != 0 {
            crate::bail!("in_channel {c_in} is not divisible by the number of groups")
        }
        let params = ParamsConvTranspose1D {
            b_size,
            l_in,
            k_size,
            c_out,
            c_in: c_in / groups,
            padding,
            output_padding,
            stride,
            dilation,
        };
        if groups == 1 {
            self.conv_transpose1d_single_group(kernel, &params)
        } else {
            // Split input channels and the kernel's in-channel dim (dim 0),
            // run each group independently, then concatenate on dim 1.
            let blocks = self.chunk(groups, 1)?;
            let kernel = kernel.chunk(groups, 0)?;
            let blocks = blocks
                .iter()
                .zip(&kernel)
                .map(|(block, kernel)| block.conv_transpose1d_single_group(kernel, &params))
                .collect::<Result<Vec<_>>>()?;
            Tensor::cat(&blocks, 1)
        }
    }

    /// Runs a 2D convolution for a single group; see `conv1d_single_group`
    /// for the overall structure.
    fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
        let storage =
            self.storage()
                .conv2d(self.layout(), &kernel.storage(), kernel.layout(), params)?;
        let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D {
            arg,
            kernel,
            padding: params.padding,
            stride: params.stride,
            dilation: params.dilation,
        });
        let out_dims = params.out_dims();
        Ok(crate::tensor::from_storage(storage, out_dims, op, false))
    }

    /// Applies a 2D convolution over the input tensor.
    ///
    /// - `self`: input of shape `(b_size, c_in, i_h, i_w)`.
    /// - `kernel`: weights of shape `(c_out, c_in / groups, k_h, k_w)`.
    ///
    /// Returns a tensor of shape `(b_size, c_out, out_h, out_w)`. The cudnn
    /// forward algorithm is left unset here (`cudnn_fwd_algo: None`).
    pub fn conv2d(
        &self,
        kernel: &Self,
        padding: usize,
        stride: usize,
        dilation: usize,
        groups: usize,
    ) -> Result<Self> {
        let (b_size, c_in, i_h, i_w) = self.dims4()?;
        let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
        if c_in != c_in_k * groups {
            crate::bail!(
                "in_channel mismatch between input ({c_in}, groups {groups}) and kernel ({c_in_k})"
            )
        }
        let params = ParamsConv2D {
            b_size,
            i_h,
            i_w,
            k_h,
            k_w,
            c_out: c_out / groups,
            c_in: c_in / groups,
            padding,
            stride,
            dilation,
            cudnn_fwd_algo: None,
        };
        if groups == 1 {
            self.conv2d_single_group(kernel, &params)
        } else {
            // Same grouped strategy as conv1d: chunk input channels and
            // kernel out-channels, convolve per group, concatenate on dim 1.
            let blocks = self.chunk(groups, 1)?;
            let kernel = kernel.chunk(groups, 0)?;
            let blocks = blocks
                .iter()
                .zip(&kernel)
                .map(|(block, kernel)| block.conv2d_single_group(kernel, &params))
                .collect::<Result<Vec<_>>>()?;
            Tensor::cat(&blocks, 1)
        }
    }

    /// Applies a 2D transposed convolution over the input tensor.
    ///
    /// - `self`: input of shape `(b_size, c_in, i_h, i_w)`.
    /// - `kernel`: weights of shape `(c_in, c_out, k_h, k_w)`.
    ///
    /// Unlike the other entry points in this impl, no `groups` argument is
    /// supported here — the whole input is processed as a single group.
    pub fn conv_transpose2d(
        &self,
        kernel: &Self,
        padding: usize,
        output_padding: usize,
        stride: usize,
        dilation: usize,
    ) -> Result<Self> {
        let (b_size, c_in, i_h, i_w) = self.dims4()?;
        let (c_in_k, c_out, k_h, k_w) = kernel.dims4()?;
        if c_in != c_in_k {
            crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
        }
        let params = ParamsConvTranspose2D {
            b_size,
            i_h,
            i_w,
            k_h,
            k_w,
            c_out,
            c_in,
            padding,
            output_padding,
            stride,
            dilation,
        };
        let storage = self.storage().conv_transpose2d(
            self.layout(),
            &kernel.storage(),
            kernel.layout(),
            &params,
        )?;
        let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose2D {
            arg,
            kernel,
            padding: params.padding,
            output_padding: params.output_padding,
            stride: params.stride,
            dilation: params.dilation,
        });
        let out_dims = params.out_dims();
        Ok(crate::tensor::from_storage(storage, out_dims, op, false))
    }
}