// candle_nn/conv.rs

1//! Convolution Layers.
2use crate::BatchNorm;
3use candle::{conv::CudnnFwdAlgo, Result, Tensor};
4
/// Configuration for a [`Conv1d`] layer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Conv1dConfig {
    /// Padding applied to the input before the convolution.
    pub padding: usize,
    /// Stride of the convolution.
    pub stride: usize,
    /// Dilation factor of the kernel.
    pub dilation: usize,
    /// Number of groups; the weight's input-channel dimension is
    /// `in_channels / groups` (see the `conv1d` builders below).
    pub groups: usize,
    /// Optional explicit cuDNN forward algorithm, forwarded to
    /// `conv1d_with_algo`; `None` lets the backend choose.
    pub cudnn_fwd_algo: Option<CudnnFwdAlgo>,
}
13
14impl Default for Conv1dConfig {
15    fn default() -> Self {
16        Self {
17            padding: 0,
18            stride: 1,
19            dilation: 1,
20            groups: 1,
21            cudnn_fwd_algo: None,
22        }
23    }
24}
25
/// A 1D convolution layer.
///
/// Built directly via [`Conv1d::new`] or from a `VarBuilder` via the
/// [`conv1d`] / [`conv1d_no_bias`] helpers.
#[derive(Clone, Debug)]
pub struct Conv1d {
    // Kernel tensor; the `conv1d` helpers create it with shape
    // (out_channels, in_channels / groups, kernel_size).
    weight: Tensor,
    // Optional bias; `forward` reshapes it to (1, len, 1) for broadcasting,
    // so it is expected to be a 1D tensor of length out_channels.
    bias: Option<Tensor>,
    config: Conv1dConfig,
}
32
33impl Conv1d {
34    pub fn new(weight: Tensor, bias: Option<Tensor>, config: Conv1dConfig) -> Self {
35        Self {
36            weight,
37            bias,
38            config,
39        }
40    }
41
42    pub fn config(&self) -> &Conv1dConfig {
43        &self.config
44    }
45
46    pub fn weight(&self) -> &Tensor {
47        &self.weight
48    }
49
50    pub fn bias(&self) -> Option<&Tensor> {
51        self.bias.as_ref()
52    }
53}
54
55impl crate::Module for Conv1d {
56    fn forward(&self, x: &Tensor) -> Result<Tensor> {
57        let x = x.conv1d_with_algo(
58            &self.weight,
59            self.config.padding,
60            self.config.stride,
61            self.config.dilation,
62            self.config.groups,
63            self.config.cudnn_fwd_algo,
64        )?;
65        match &self.bias {
66            None => Ok(x),
67            Some(bias) => {
68                let b = bias.dims1()?;
69                let bias = bias.reshape((1, b, 1))?;
70                Ok(x.broadcast_add(&bias)?)
71            }
72        }
73    }
74}
75
/// Configuration for a [`ConvTranspose1d`] layer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ConvTranspose1dConfig {
    /// Padding applied to the input.
    pub padding: usize,
    /// Additional size added to one side of the output.
    pub output_padding: usize,
    /// Stride of the transposed convolution.
    pub stride: usize,
    /// Dilation factor of the kernel.
    pub dilation: usize,
    /// Number of groups; the weight's second dimension is
    /// `out_channels / groups` (see the `conv_transpose1d` builders below).
    pub groups: usize,
}
84
85impl Default for ConvTranspose1dConfig {
86    fn default() -> Self {
87        Self {
88            padding: 0,
89            output_padding: 0,
90            stride: 1,
91            dilation: 1,
92            groups: 1,
93        }
94    }
95}
96
/// A 1D transposed (fractionally-strided) convolution layer.
///
/// Built directly via [`ConvTranspose1d::new`] or from a `VarBuilder` via the
/// [`conv_transpose1d`] / [`conv_transpose1d_no_bias`] helpers.
#[derive(Clone, Debug)]
pub struct ConvTranspose1d {
    // Kernel tensor; the `conv_transpose1d` helpers create it with shape
    // (in_channels, out_channels / groups, kernel_size).
    weight: Tensor,
    // Optional bias; `forward` reshapes it to (1, len, 1) for broadcasting,
    // so it is expected to be a 1D tensor of length out_channels.
    bias: Option<Tensor>,
    config: ConvTranspose1dConfig,
}
103
104impl ConvTranspose1d {
105    pub fn new(weight: Tensor, bias: Option<Tensor>, config: ConvTranspose1dConfig) -> Self {
106        Self {
107            weight,
108            bias,
109            config,
110        }
111    }
112
113    pub fn config(&self) -> &ConvTranspose1dConfig {
114        &self.config
115    }
116
117    pub fn weight(&self) -> &Tensor {
118        &self.weight
119    }
120
121    pub fn bias(&self) -> Option<&Tensor> {
122        self.bias.as_ref()
123    }
124}
125
126impl crate::Module for ConvTranspose1d {
127    fn forward(&self, x: &Tensor) -> Result<Tensor> {
128        let x = x.conv_transpose1d(
129            &self.weight,
130            self.config.padding,
131            self.config.output_padding,
132            self.config.stride,
133            self.config.dilation,
134            self.config.groups,
135        )?;
136        match &self.bias {
137            None => Ok(x),
138            Some(bias) => {
139                let b = bias.dims1()?;
140                let bias = bias.reshape((1, b, 1))?;
141                Ok(x.broadcast_add(&bias)?)
142            }
143        }
144    }
145}
146
/// Configuration for a [`Conv2d`] layer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Conv2dConfig {
    /// Padding applied to the input before the convolution.
    pub padding: usize,
    /// Stride of the convolution.
    pub stride: usize,
    /// Dilation factor of the kernel.
    pub dilation: usize,
    /// Number of groups; the weight's input-channel dimension is
    /// `in_channels / groups` (see the `conv2d` builders below).
    pub groups: usize,
    /// Optional explicit cuDNN forward algorithm, forwarded to
    /// `conv2d_with_algo`; `None` lets the backend choose.
    pub cudnn_fwd_algo: Option<CudnnFwdAlgo>,
}
155
156impl Default for Conv2dConfig {
157    fn default() -> Self {
158        Self {
159            padding: 0,
160            stride: 1,
161            dilation: 1,
162            groups: 1,
163            cudnn_fwd_algo: None,
164        }
165    }
166}
167
/// A 2D convolution layer.
///
/// Built directly via [`Conv2d::new`] or from a `VarBuilder` via the
/// [`conv2d`] / [`conv2d_no_bias`] helpers.
#[derive(Clone, Debug)]
pub struct Conv2d {
    // Kernel tensor; the `conv2d` helpers create it with shape
    // (out_channels, in_channels / groups, kernel_size, kernel_size).
    weight: Tensor,
    // Optional bias; `forward` reshapes it to (1, len, 1, 1) for
    // broadcasting, so it is expected to be a 1D tensor of length
    // out_channels.
    bias: Option<Tensor>,
    config: Conv2dConfig,
}
174
impl Conv2d {
    /// Builds a `Conv2d` from an existing weight tensor, optional bias and
    /// configuration.
    pub fn new(weight: Tensor, bias: Option<Tensor>, config: Conv2dConfig) -> Self {
        Self {
            weight,
            bias,
            config,
        }
    }

    /// The configuration used by this layer.
    pub fn config(&self) -> &Conv2dConfig {
        &self.config
    }

    /// The kernel weight tensor.
    pub fn weight(&self) -> &Tensor {
        &self.weight
    }

    /// The bias tensor, if any.
    pub fn bias(&self) -> Option<&Tensor> {
        self.bias.as_ref()
    }

    /// Folds an inference-mode [`BatchNorm`] into this convolution, returning
    /// a new `Conv2d` whose single convolution computes `bn(conv(x))`.
    ///
    /// With `s = bn_weight / sqrt(running_var + eps)`, the new weight is the
    /// old weight scaled per output channel by `s`, and the new bias is
    /// `bn_bias + s * (conv_bias - running_mean)` (with `conv_bias` taken as
    /// zero when this layer has no bias).
    ///
    /// # Errors
    /// Fails when the batch norm has no affine weight/bias pair, or when any
    /// of the intermediate tensor operations fail.
    pub fn absorb_bn(&self, bn: &BatchNorm) -> Result<Self> {
        if let Some((w_bn, b_bn)) = bn.weight_and_bias() {
            // Per-channel scale s = gamma / sqrt(running_var + eps).
            let std_ = w_bn.div(&((bn.running_var() + bn.eps())?.sqrt()?))?;
            // Reshape s to (out_channels, 1, 1, 1) so each output filter of
            // the 4D weight is scaled by its own factor.
            let weight = self
                .weight()
                .broadcast_mul(&(std_.reshape((self.weight().dims4()?.0, 1, 1, 1))?))?;
            let bias = match &self.bias {
                // No conv bias: beta - s * running_mean.
                None => b_bn.sub(&(std_.mul(bn.running_mean())?))?,
                // With conv bias: beta + s * (bias - running_mean).
                Some(bias) => b_bn.add(&(std_.mul(&bias.sub(bn.running_mean())?)?))?,
            };
            Ok(Self {
                weight,
                bias: Some(bias),
                config: self.config,
            })
        } else {
            candle::bail!("batch norm does not have weight_and_bias")
        }
    }
}
216
217impl crate::Module for Conv2d {
218    fn forward(&self, x: &Tensor) -> Result<Tensor> {
219        let x = x.conv2d_with_algo(
220            &self.weight,
221            self.config.padding,
222            self.config.stride,
223            self.config.dilation,
224            self.config.groups,
225            self.config.cudnn_fwd_algo,
226        )?;
227        match &self.bias {
228            None => Ok(x),
229            Some(bias) => {
230                let b = bias.dims1()?;
231                let bias = bias.reshape((1, b, 1, 1))?;
232                Ok(x.broadcast_add(&bias)?)
233            }
234        }
235    }
236}
237
/// Configuration for a [`ConvTranspose2d`] layer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ConvTranspose2dConfig {
    /// Padding applied to the input.
    pub padding: usize,
    /// Additional size added to one side of each output dimension.
    pub output_padding: usize,
    /// Stride of the transposed convolution.
    pub stride: usize,
    /// Dilation factor of the kernel.
    pub dilation: usize,
    // TODO: support groups.
}
246
247impl Default for ConvTranspose2dConfig {
248    fn default() -> Self {
249        Self {
250            padding: 0,
251            output_padding: 0,
252            stride: 1,
253            dilation: 1,
254        }
255    }
256}
257
/// A 2D transposed (fractionally-strided) convolution layer.
///
/// Built directly via [`ConvTranspose2d::new`] or from a `VarBuilder` via the
/// [`conv_transpose2d`] / [`conv_transpose2d_no_bias`] helpers.
#[derive(Clone, Debug)]
pub struct ConvTranspose2d {
    // Kernel tensor; the `conv_transpose2d` helpers create it with shape
    // (in_channels, out_channels, kernel_size, kernel_size).
    weight: Tensor,
    // Optional bias; `forward` reshapes it to (1, len, 1, 1) for
    // broadcasting, so it is expected to be a 1D tensor of length
    // out_channels.
    bias: Option<Tensor>,
    config: ConvTranspose2dConfig,
}
264
265impl ConvTranspose2d {
266    pub fn new(weight: Tensor, bias: Option<Tensor>, config: ConvTranspose2dConfig) -> Self {
267        Self {
268            weight,
269            bias,
270            config,
271        }
272    }
273
274    pub fn config(&self) -> &ConvTranspose2dConfig {
275        &self.config
276    }
277
278    pub fn weight(&self) -> &Tensor {
279        &self.weight
280    }
281
282    pub fn bias(&self) -> Option<&Tensor> {
283        self.bias.as_ref()
284    }
285}
286
287impl crate::Module for ConvTranspose2d {
288    fn forward(&self, x: &Tensor) -> Result<Tensor> {
289        let x = x.conv_transpose2d(
290            &self.weight,
291            self.config.padding,
292            self.config.output_padding,
293            self.config.stride,
294            self.config.dilation,
295        )?;
296        match &self.bias {
297            None => Ok(x),
298            Some(bias) => {
299                let b = bias.dims1()?;
300                let bias = bias.reshape((1, b, 1, 1))?;
301                Ok(x.broadcast_add(&bias)?)
302            }
303        }
304    }
305}
306
307pub fn conv1d(
308    in_channels: usize,
309    out_channels: usize,
310    kernel_size: usize,
311    cfg: Conv1dConfig,
312    vb: crate::VarBuilder,
313) -> Result<Conv1d> {
314    let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
315    let ws = vb.get_with_hints(
316        (out_channels, in_channels / cfg.groups, kernel_size),
317        "weight",
318        init_ws,
319    )?;
320    let bound = 1. / (in_channels as f64).sqrt();
321    let init_bs = crate::Init::Uniform {
322        lo: -bound,
323        up: bound,
324    };
325    let bs = vb.get_with_hints(out_channels, "bias", init_bs)?;
326    Ok(Conv1d::new(ws, Some(bs), cfg))
327}
328
329pub fn conv1d_no_bias(
330    in_channels: usize,
331    out_channels: usize,
332    kernel_size: usize,
333    cfg: Conv1dConfig,
334    vb: crate::VarBuilder,
335) -> Result<Conv1d> {
336    let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
337    let ws = vb.get_with_hints(
338        (out_channels, in_channels / cfg.groups, kernel_size),
339        "weight",
340        init_ws,
341    )?;
342    Ok(Conv1d::new(ws, None, cfg))
343}
344
345pub fn conv_transpose1d(
346    in_channels: usize,
347    out_channels: usize,
348    kernel_size: usize,
349    cfg: ConvTranspose1dConfig,
350    vb: crate::VarBuilder,
351) -> Result<ConvTranspose1d> {
352    let bound = 1. / (out_channels as f64 * kernel_size as f64).sqrt();
353    let init = crate::Init::Uniform {
354        lo: -bound,
355        up: bound,
356    };
357    let ws = vb.get_with_hints(
358        (in_channels, out_channels / cfg.groups, kernel_size),
359        "weight",
360        init,
361    )?;
362    let bs = vb.get_with_hints(out_channels, "bias", init)?;
363    Ok(ConvTranspose1d::new(ws, Some(bs), cfg))
364}
365
366pub fn conv_transpose1d_no_bias(
367    in_channels: usize,
368    out_channels: usize,
369    kernel_size: usize,
370    cfg: ConvTranspose1dConfig,
371    vb: crate::VarBuilder,
372) -> Result<ConvTranspose1d> {
373    let bound = 1. / (out_channels as f64 * kernel_size as f64).sqrt();
374    let init = crate::Init::Uniform {
375        lo: -bound,
376        up: bound,
377    };
378    let ws = vb.get_with_hints(
379        (in_channels, out_channels / cfg.groups, kernel_size),
380        "weight",
381        init,
382    )?;
383    Ok(ConvTranspose1d::new(ws, None, cfg))
384}
385
386pub fn conv2d(
387    in_channels: usize,
388    out_channels: usize,
389    kernel_size: usize,
390    cfg: Conv2dConfig,
391    vb: crate::VarBuilder,
392) -> Result<Conv2d> {
393    let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
394    let ws = vb.get_with_hints(
395        (
396            out_channels,
397            in_channels / cfg.groups,
398            kernel_size,
399            kernel_size,
400        ),
401        "weight",
402        init_ws,
403    )?;
404    let bound = 1. / (in_channels as f64).sqrt();
405    let init_bs = crate::Init::Uniform {
406        lo: -bound,
407        up: bound,
408    };
409    let bs = vb.get_with_hints(out_channels, "bias", init_bs)?;
410    Ok(Conv2d::new(ws, Some(bs), cfg))
411}
412
413pub fn conv2d_no_bias(
414    in_channels: usize,
415    out_channels: usize,
416    kernel_size: usize,
417    cfg: Conv2dConfig,
418    vb: crate::VarBuilder,
419) -> Result<Conv2d> {
420    let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
421    let ws = vb.get_with_hints(
422        (
423            out_channels,
424            in_channels / cfg.groups,
425            kernel_size,
426            kernel_size,
427        ),
428        "weight",
429        init_ws,
430    )?;
431    Ok(Conv2d::new(ws, None, cfg))
432}
433
434pub fn conv_transpose2d(
435    in_channels: usize,
436    out_channels: usize,
437    kernel_size: usize,
438    cfg: ConvTranspose2dConfig,
439    vb: crate::VarBuilder,
440) -> Result<ConvTranspose2d> {
441    let bound = 1. / (out_channels as f64).sqrt() / kernel_size as f64;
442    let init = crate::Init::Uniform {
443        lo: -bound,
444        up: bound,
445    };
446    let ws = vb.get_with_hints(
447        (in_channels, out_channels, kernel_size, kernel_size),
448        "weight",
449        init,
450    )?;
451    let bs = vb.get_with_hints(out_channels, "bias", init)?;
452    Ok(ConvTranspose2d::new(ws, Some(bs), cfg))
453}
454
455pub fn conv_transpose2d_no_bias(
456    in_channels: usize,
457    out_channels: usize,
458    kernel_size: usize,
459    cfg: ConvTranspose2dConfig,
460    vb: crate::VarBuilder,
461) -> Result<ConvTranspose2d> {
462    let bound = 1. / (out_channels as f64).sqrt() / kernel_size as f64;
463    let init = crate::Init::Uniform {
464        lo: -bound,
465        up: bound,
466    };
467    let ws = vb.get_with_hints(
468        (in_channels, out_channels, kernel_size, kernel_size),
469        "weight",
470        init,
471    )?;
472    Ok(ConvTranspose2d::new(ws, None, cfg))
473}