// candle_core_temp/conv.rs

use crate::{op::BackpropOp, op::Op, Error, Result, Tensor};
2
/// Hyper-parameters describing a 1D convolution.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv1D {
    /// Batch size of the input tensor.
    pub(crate) b_size: usize,
    // Maybe we should have a version without l_in as this bit depends on the input and not only on
    // the weights.
    /// Length of the input along the convolved dimension.
    pub(crate) l_in: usize,
    /// Number of output channels (per group when grouped convolutions are used).
    pub(crate) c_out: usize,
    /// Number of input channels (per group when grouped convolutions are used).
    pub(crate) c_in: usize,
    /// Kernel (filter) width.
    pub(crate) k_size: usize,
    /// Zero-padding applied on both sides of the input.
    pub(crate) padding: usize,
    /// Stride of the convolution window.
    pub(crate) stride: usize,
    /// Spacing between kernel elements.
    pub(crate) dilation: usize,
}
16
17impl ParamsConv1D {
18    pub(crate) fn l_out(&self) -> usize {
19        (self.l_in + 2 * self.padding - self.dilation * (self.k_size - 1) - 1) / self.stride + 1
20    }
21
22    pub(crate) fn out_dims(&self) -> Vec<usize> {
23        let l_out = self.l_out();
24        vec![self.b_size, self.c_out, l_out]
25    }
26}
27
/// Forward-pass algorithm selection for cuDNN convolutions.
///
/// NOTE(review): the variants appear to mirror cuDNN's `cudnnConvolutionFwdAlgo_t`
/// values — confirm against the backend that performs the mapping.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CudnnFwdAlgo {
    ImplicitGemm,
    ImplicitPrecompGemm,
    Gemm,
    Direct,
    Fft,
    FftTiling,
    Winograd,
    WinogradNonFused,
    Count,
}
40
/// Hyper-parameters describing a 2D convolution.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv2D {
    /// Batch size of the input tensor.
    pub(crate) b_size: usize,
    /// Input height.
    pub(crate) i_h: usize,
    /// Input width.
    pub(crate) i_w: usize,
    /// Kernel height.
    pub(crate) k_h: usize,
    /// Kernel width.
    pub(crate) k_w: usize,
    /// Number of output channels (per group when grouped convolutions are used).
    pub(crate) c_out: usize,
    /// Number of input channels (per group when grouped convolutions are used).
    pub(crate) c_in: usize,
    /// Zero-padding, applied identically on both spatial dimensions.
    pub(crate) padding: usize,
    /// Stride, applied identically on both spatial dimensions.
    pub(crate) stride: usize,
    /// Dilation, applied identically on both spatial dimensions.
    pub(crate) dilation: usize,
    /// Optional cuDNN forward-algorithm override; `Tensor::conv2d` leaves this as `None`.
    pub cudnn_fwd_algo: Option<CudnnFwdAlgo>,
}
55
56impl ParamsConv2D {
57    pub(crate) fn out_h(&self) -> usize {
58        (self.i_h + 2 * self.padding - self.dilation * (self.k_h - 1) - 1) / self.stride + 1
59    }
60
61    pub(crate) fn out_w(&self) -> usize {
62        (self.i_w + 2 * self.padding - self.dilation * (self.k_w - 1) - 1) / self.stride + 1
63    }
64
65    pub(crate) fn out_dims(&self) -> Vec<usize> {
66        vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
67    }
68}
69
/// Hyper-parameters describing a 2D transposed convolution.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConvTranspose2D {
    /// Batch size of the input tensor.
    pub(crate) b_size: usize,
    /// Input height.
    pub(crate) i_h: usize,
    /// Input width.
    pub(crate) i_w: usize,
    /// Kernel height.
    pub(crate) k_h: usize,
    /// Kernel width.
    pub(crate) k_w: usize,
    /// Number of output channels.
    pub(crate) c_out: usize,
    /// Number of input channels.
    pub(crate) c_in: usize,
    /// Zero-padding, applied identically on both spatial dimensions.
    pub(crate) padding: usize,
    /// Extra size added once to each spatial dimension of the output.
    pub(crate) output_padding: usize,
    /// Stride, applied identically on both spatial dimensions.
    pub(crate) stride: usize,
    /// Dilation, applied identically on both spatial dimensions.
    pub(crate) dilation: usize,
}
84
85impl ParamsConvTranspose2D {
86    pub(crate) fn out_h(&self) -> usize {
87        (self.i_h - 1) * self.stride + self.dilation * (self.k_h - 1) + self.output_padding + 1
88            - 2 * self.padding
89    }
90
91    pub(crate) fn out_w(&self) -> usize {
92        (self.i_w - 1) * self.stride + self.dilation * (self.k_w - 1) + self.output_padding + 1
93            - 2 * self.padding
94    }
95
96    pub(crate) fn out_dims(&self) -> Vec<usize> {
97        vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
98    }
99}
100
101impl Tensor {
102    fn conv1d_single_group(&self, kernel: &Self, params: &ParamsConv1D) -> Result<Self> {
103        let storage =
104            self.storage()
105                .conv1d(self.layout(), &kernel.storage(), kernel.layout(), params)?;
106        let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv1D {
107            arg,
108            kernel,
109            padding: params.padding,
110            stride: params.stride,
111            dilation: params.dilation,
112        });
113        let out_dims = params.out_dims();
114        Ok(crate::tensor::from_storage(storage, out_dims, op, false))
115    }
116
117    /// Applies a 1D convolution over the input tensor.
118    pub fn conv1d(
119        &self,
120        kernel: &Self,
121        padding: usize,
122        stride: usize,
123        dilation: usize,
124        groups: usize,
125    ) -> Result<Self> {
126        let (c_out, c_in_k, k_size) = kernel.dims3()?;
127        let (b_size, c_in, l_in) = self.dims3()?;
128        if c_in != c_in_k * groups {
129            Err(Error::Conv1dInvalidArgs {
130                inp_shape: self.shape().clone(),
131                k_shape: kernel.shape().clone(),
132                padding,
133                stride,
134                msg: "the number of in-channels on the input doesn't match the kernel size",
135            }
136            .bt())?
137        }
138
139        let params = ParamsConv1D {
140            b_size,
141            l_in,
142            c_out: c_out / groups,
143            c_in: c_in / groups,
144            k_size,
145            padding,
146            stride,
147            dilation,
148        };
149        if groups == 1 {
150            self.conv1d_single_group(kernel, &params)
151        } else {
152            let blocks = self.chunk(groups, 1)?;
153            let kernel = kernel.chunk(groups, 0)?;
154            let blocks = blocks
155                .iter()
156                .zip(&kernel)
157                .map(|(block, kernel)| block.conv1d_single_group(kernel, &params))
158                .collect::<Result<Vec<_>>>()?;
159            Tensor::cat(&blocks, 1)
160        }
161    }
162
163    fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
164        let storage =
165            self.storage()
166                .conv2d(self.layout(), &kernel.storage(), kernel.layout(), params)?;
167        let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D {
168            arg,
169            kernel,
170            padding: params.padding,
171            stride: params.stride,
172            dilation: params.dilation,
173        });
174        let out_dims = params.out_dims();
175        Ok(crate::tensor::from_storage(storage, out_dims, op, false))
176    }
177
178    /// Applies a 2D convolution over the input tensor.
179    pub fn conv2d(
180        &self,
181        kernel: &Self,
182        padding: usize,
183        stride: usize,
184        dilation: usize,
185        groups: usize,
186    ) -> Result<Self> {
187        let (b_size, c_in, i_h, i_w) = self.dims4()?;
188        let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
189        if c_in != c_in_k * groups {
190            crate::bail!(
191                "in_channel mismatch between input ({c_in}, groups {groups}) and kernel ({c_in_k})"
192            )
193        }
194        let params = ParamsConv2D {
195            b_size,
196            i_h,
197            i_w,
198            k_h,
199            k_w,
200            c_out: c_out / groups,
201            c_in: c_in / groups,
202            padding,
203            stride,
204            dilation,
205            cudnn_fwd_algo: None,
206        };
207        if groups == 1 {
208            self.conv2d_single_group(kernel, &params)
209        } else {
210            let blocks = self.chunk(groups, 1)?;
211            let kernel = kernel.chunk(groups, 0)?;
212            let blocks = blocks
213                .iter()
214                .zip(&kernel)
215                .map(|(block, kernel)| block.conv2d_single_group(kernel, &params))
216                .collect::<Result<Vec<_>>>()?;
217            Tensor::cat(&blocks, 1)
218        }
219    }
220
221    /// Applies a 2D transposed convolution over the input tensor.
222    pub fn conv_transpose2d(
223        &self,
224        kernel: &Self,
225        padding: usize,
226        output_padding: usize,
227        stride: usize,
228        dilation: usize,
229    ) -> Result<Self> {
230        let (b_size, c_in, i_h, i_w) = self.dims4()?;
231        let (c_in_k, c_out, k_h, k_w) = kernel.dims4()?;
232        if c_in != c_in_k {
233            crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
234        }
235        let params = ParamsConvTranspose2D {
236            b_size,
237            i_h,
238            i_w,
239            k_h,
240            k_w,
241            c_out,
242            c_in,
243            padding,
244            output_padding,
245            stride,
246            dilation,
247        };
248        let storage = self.storage().conv_transpose2d(
249            self.layout(),
250            &kernel.storage(),
251            kernel.layout(),
252            &params,
253        )?;
254        let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose2D {
255            arg,
256            kernel,
257            padding: params.padding,
258            output_padding: params.output_padding,
259            stride: params.stride,
260            dilation: params.dilation,
261        });
262        let out_dims = params.out_dims();
263        Ok(crate::tensor::from_storage(storage, out_dims, op, false))
264    }
265}