1use crate::BatchNorm;
3use candle::{conv::CudnnFwdAlgo, Result, Tensor};
4
5#[derive(Debug, Clone, Copy, PartialEq, Eq)]
6pub struct Conv1dConfig {
7 pub padding: usize,
8 pub stride: usize,
9 pub dilation: usize,
10 pub groups: usize,
11 pub cudnn_fwd_algo: Option<CudnnFwdAlgo>,
12}
13
14impl Default for Conv1dConfig {
15 fn default() -> Self {
16 Self {
17 padding: 0,
18 stride: 1,
19 dilation: 1,
20 groups: 1,
21 cudnn_fwd_algo: None,
22 }
23 }
24}
25
26#[derive(Clone, Debug)]
27pub struct Conv1d {
28 weight: Tensor,
29 bias: Option<Tensor>,
30 config: Conv1dConfig,
31}
32
33impl Conv1d {
34 pub fn new(weight: Tensor, bias: Option<Tensor>, config: Conv1dConfig) -> Self {
35 Self {
36 weight,
37 bias,
38 config,
39 }
40 }
41
42 pub fn config(&self) -> &Conv1dConfig {
43 &self.config
44 }
45
46 pub fn weight(&self) -> &Tensor {
47 &self.weight
48 }
49
50 pub fn bias(&self) -> Option<&Tensor> {
51 self.bias.as_ref()
52 }
53}
54
55impl crate::Module for Conv1d {
56 fn forward(&self, x: &Tensor) -> Result<Tensor> {
57 let x = x.conv1d_with_algo(
58 &self.weight,
59 self.config.padding,
60 self.config.stride,
61 self.config.dilation,
62 self.config.groups,
63 self.config.cudnn_fwd_algo,
64 )?;
65 match &self.bias {
66 None => Ok(x),
67 Some(bias) => {
68 let b = bias.dims1()?;
69 let bias = bias.reshape((1, b, 1))?;
70 Ok(x.broadcast_add(&bias)?)
71 }
72 }
73 }
74}
75
/// Hyper-parameters of a [`ConvTranspose1d`] layer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ConvTranspose1dConfig {
    pub padding: usize,
    pub output_padding: usize,
    pub stride: usize,
    pub dilation: usize,
    pub groups: usize,
}

impl Default for ConvTranspose1dConfig {
    /// No padding or output padding, unit stride/dilation, a single group.
    fn default() -> Self {
        Self { padding: 0, output_padding: 0, stride: 1, dilation: 1, groups: 1 }
    }
}
96
97#[derive(Clone, Debug)]
98pub struct ConvTranspose1d {
99 weight: Tensor,
100 bias: Option<Tensor>,
101 config: ConvTranspose1dConfig,
102}
103
104impl ConvTranspose1d {
105 pub fn new(weight: Tensor, bias: Option<Tensor>, config: ConvTranspose1dConfig) -> Self {
106 Self {
107 weight,
108 bias,
109 config,
110 }
111 }
112
113 pub fn config(&self) -> &ConvTranspose1dConfig {
114 &self.config
115 }
116
117 pub fn weight(&self) -> &Tensor {
118 &self.weight
119 }
120
121 pub fn bias(&self) -> Option<&Tensor> {
122 self.bias.as_ref()
123 }
124}
125
126impl crate::Module for ConvTranspose1d {
127 fn forward(&self, x: &Tensor) -> Result<Tensor> {
128 let x = x.conv_transpose1d(
129 &self.weight,
130 self.config.padding,
131 self.config.output_padding,
132 self.config.stride,
133 self.config.dilation,
134 self.config.groups,
135 )?;
136 match &self.bias {
137 None => Ok(x),
138 Some(bias) => {
139 let b = bias.dims1()?;
140 let bias = bias.reshape((1, b, 1))?;
141 Ok(x.broadcast_add(&bias)?)
142 }
143 }
144 }
145}
146
147#[derive(Debug, Clone, Copy, PartialEq, Eq)]
148pub struct Conv2dConfig {
149 pub padding: usize,
150 pub stride: usize,
151 pub dilation: usize,
152 pub groups: usize,
153 pub cudnn_fwd_algo: Option<CudnnFwdAlgo>,
154}
155
156impl Default for Conv2dConfig {
157 fn default() -> Self {
158 Self {
159 padding: 0,
160 stride: 1,
161 dilation: 1,
162 groups: 1,
163 cudnn_fwd_algo: None,
164 }
165 }
166}
167
/// A 2D convolution layer holding its weight tensor, optional bias, and
/// configuration.
#[derive(Clone, Debug)]
pub struct Conv2d {
    weight: Tensor,
    bias: Option<Tensor>,
    config: Conv2dConfig,
}

impl Conv2d {
    /// Wraps an existing weight tensor (and optional bias) into a layer.
    pub fn new(weight: Tensor, bias: Option<Tensor>, config: Conv2dConfig) -> Self {
        Self {
            weight,
            bias,
            config,
        }
    }

    /// The layer configuration.
    pub fn config(&self) -> &Conv2dConfig {
        &self.config
    }

    /// The weight tensor; the constructors below lay it out as
    /// `(out_channels, in_channels / groups, kernel, kernel)`.
    pub fn weight(&self) -> &Tensor {
        &self.weight
    }

    /// The bias tensor, if any (one value per output channel).
    pub fn bias(&self) -> Option<&Tensor> {
        self.bias.as_ref()
    }

    /// Folds the batch-norm statistics of `bn` into this convolution,
    /// returning a new `Conv2d` that computes `bn(conv(x))` in one pass.
    ///
    /// With `gamma`/`beta` the affine batch-norm weight and bias, the scale is
    /// `std_ = gamma / sqrt(running_var + eps)`; the conv weight is multiplied
    /// by `std_` per output channel, and the new bias is
    /// `beta + std_ * (old_bias - running_mean)` (treating a missing bias as 0).
    ///
    /// # Errors
    /// Fails if `bn` has no affine weight/bias pair, or if any tensor op fails.
    pub fn absorb_bn(&self, bn: &BatchNorm) -> Result<Self> {
        if let Some((w_bn, b_bn)) = bn.weight_and_bias() {
            let std_ = w_bn.div(&((bn.running_var() + bn.eps())?.sqrt()?))?;
            // Reshape the per-channel scale to (out_channels, 1, 1, 1) so it
            // broadcasts over the remaining weight axes.
            let weight = self
                .weight()
                .broadcast_mul(&(std_.reshape((self.weight().dims4()?.0, 1, 1, 1))?))?;
            let bias = match &self.bias {
                None => b_bn.sub(&(std_.mul(bn.running_mean())?))?,
                Some(bias) => b_bn.add(&(std_.mul(&bias.sub(bn.running_mean())?)?))?,
            };
            Ok(Self {
                weight,
                bias: Some(bias),
                config: self.config,
            })
        } else {
            candle::bail!("batch norm does not have weight_and_bias")
        }
    }
}
216
217impl crate::Module for Conv2d {
218 fn forward(&self, x: &Tensor) -> Result<Tensor> {
219 let x = x.conv2d_with_algo(
220 &self.weight,
221 self.config.padding,
222 self.config.stride,
223 self.config.dilation,
224 self.config.groups,
225 self.config.cudnn_fwd_algo,
226 )?;
227 match &self.bias {
228 None => Ok(x),
229 Some(bias) => {
230 let b = bias.dims1()?;
231 let bias = bias.reshape((1, b, 1, 1))?;
232 Ok(x.broadcast_add(&bias)?)
233 }
234 }
235 }
236}
237
/// Hyper-parameters of a [`ConvTranspose2d`] layer (no group support).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ConvTranspose2dConfig {
    pub padding: usize,
    pub output_padding: usize,
    pub stride: usize,
    pub dilation: usize,
}

impl Default for ConvTranspose2dConfig {
    /// No padding or output padding, unit stride and dilation.
    fn default() -> Self {
        Self { padding: 0, output_padding: 0, stride: 1, dilation: 1 }
    }
}
257
258#[derive(Clone, Debug)]
259pub struct ConvTranspose2d {
260 weight: Tensor,
261 bias: Option<Tensor>,
262 config: ConvTranspose2dConfig,
263}
264
265impl ConvTranspose2d {
266 pub fn new(weight: Tensor, bias: Option<Tensor>, config: ConvTranspose2dConfig) -> Self {
267 Self {
268 weight,
269 bias,
270 config,
271 }
272 }
273
274 pub fn config(&self) -> &ConvTranspose2dConfig {
275 &self.config
276 }
277
278 pub fn weight(&self) -> &Tensor {
279 &self.weight
280 }
281
282 pub fn bias(&self) -> Option<&Tensor> {
283 self.bias.as_ref()
284 }
285}
286
287impl crate::Module for ConvTranspose2d {
288 fn forward(&self, x: &Tensor) -> Result<Tensor> {
289 let x = x.conv_transpose2d(
290 &self.weight,
291 self.config.padding,
292 self.config.output_padding,
293 self.config.stride,
294 self.config.dilation,
295 )?;
296 match &self.bias {
297 None => Ok(x),
298 Some(bias) => {
299 let b = bias.dims1()?;
300 let bias = bias.reshape((1, b, 1, 1))?;
301 Ok(x.broadcast_add(&bias)?)
302 }
303 }
304 }
305}
306
307pub fn conv1d(
308 in_channels: usize,
309 out_channels: usize,
310 kernel_size: usize,
311 cfg: Conv1dConfig,
312 vb: crate::VarBuilder,
313) -> Result<Conv1d> {
314 let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
315 let ws = vb.get_with_hints(
316 (out_channels, in_channels / cfg.groups, kernel_size),
317 "weight",
318 init_ws,
319 )?;
320 let bound = 1. / (in_channels as f64).sqrt();
321 let init_bs = crate::Init::Uniform {
322 lo: -bound,
323 up: bound,
324 };
325 let bs = vb.get_with_hints(out_channels, "bias", init_bs)?;
326 Ok(Conv1d::new(ws, Some(bs), cfg))
327}
328
329pub fn conv1d_no_bias(
330 in_channels: usize,
331 out_channels: usize,
332 kernel_size: usize,
333 cfg: Conv1dConfig,
334 vb: crate::VarBuilder,
335) -> Result<Conv1d> {
336 let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
337 let ws = vb.get_with_hints(
338 (out_channels, in_channels / cfg.groups, kernel_size),
339 "weight",
340 init_ws,
341 )?;
342 Ok(Conv1d::new(ws, None, cfg))
343}
344
345pub fn conv_transpose1d(
346 in_channels: usize,
347 out_channels: usize,
348 kernel_size: usize,
349 cfg: ConvTranspose1dConfig,
350 vb: crate::VarBuilder,
351) -> Result<ConvTranspose1d> {
352 let bound = 1. / (out_channels as f64 * kernel_size as f64).sqrt();
353 let init = crate::Init::Uniform {
354 lo: -bound,
355 up: bound,
356 };
357 let ws = vb.get_with_hints(
358 (in_channels, out_channels / cfg.groups, kernel_size),
359 "weight",
360 init,
361 )?;
362 let bs = vb.get_with_hints(out_channels, "bias", init)?;
363 Ok(ConvTranspose1d::new(ws, Some(bs), cfg))
364}
365
366pub fn conv_transpose1d_no_bias(
367 in_channels: usize,
368 out_channels: usize,
369 kernel_size: usize,
370 cfg: ConvTranspose1dConfig,
371 vb: crate::VarBuilder,
372) -> Result<ConvTranspose1d> {
373 let bound = 1. / (out_channels as f64 * kernel_size as f64).sqrt();
374 let init = crate::Init::Uniform {
375 lo: -bound,
376 up: bound,
377 };
378 let ws = vb.get_with_hints(
379 (in_channels, out_channels / cfg.groups, kernel_size),
380 "weight",
381 init,
382 )?;
383 Ok(ConvTranspose1d::new(ws, None, cfg))
384}
385
386pub fn conv2d(
387 in_channels: usize,
388 out_channels: usize,
389 kernel_size: usize,
390 cfg: Conv2dConfig,
391 vb: crate::VarBuilder,
392) -> Result<Conv2d> {
393 let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
394 let ws = vb.get_with_hints(
395 (
396 out_channels,
397 in_channels / cfg.groups,
398 kernel_size,
399 kernel_size,
400 ),
401 "weight",
402 init_ws,
403 )?;
404 let bound = 1. / (in_channels as f64).sqrt();
405 let init_bs = crate::Init::Uniform {
406 lo: -bound,
407 up: bound,
408 };
409 let bs = vb.get_with_hints(out_channels, "bias", init_bs)?;
410 Ok(Conv2d::new(ws, Some(bs), cfg))
411}
412
413pub fn conv2d_no_bias(
414 in_channels: usize,
415 out_channels: usize,
416 kernel_size: usize,
417 cfg: Conv2dConfig,
418 vb: crate::VarBuilder,
419) -> Result<Conv2d> {
420 let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
421 let ws = vb.get_with_hints(
422 (
423 out_channels,
424 in_channels / cfg.groups,
425 kernel_size,
426 kernel_size,
427 ),
428 "weight",
429 init_ws,
430 )?;
431 Ok(Conv2d::new(ws, None, cfg))
432}
433
434pub fn conv_transpose2d(
435 in_channels: usize,
436 out_channels: usize,
437 kernel_size: usize,
438 cfg: ConvTranspose2dConfig,
439 vb: crate::VarBuilder,
440) -> Result<ConvTranspose2d> {
441 let bound = 1. / (out_channels as f64).sqrt() / kernel_size as f64;
442 let init = crate::Init::Uniform {
443 lo: -bound,
444 up: bound,
445 };
446 let ws = vb.get_with_hints(
447 (in_channels, out_channels, kernel_size, kernel_size),
448 "weight",
449 init,
450 )?;
451 let bs = vb.get_with_hints(out_channels, "bias", init)?;
452 Ok(ConvTranspose2d::new(ws, Some(bs), cfg))
453}
454
455pub fn conv_transpose2d_no_bias(
456 in_channels: usize,
457 out_channels: usize,
458 kernel_size: usize,
459 cfg: ConvTranspose2dConfig,
460 vb: crate::VarBuilder,
461) -> Result<ConvTranspose2d> {
462 let bound = 1. / (out_channels as f64).sqrt() / kernel_size as f64;
463 let init = crate::Init::Uniform {
464 lo: -bound,
465 up: bound,
466 };
467 let ws = vb.get_with_hints(
468 (in_channels, out_channels, kernel_size, kernel_size),
469 "weight",
470 init,
471 )?;
472 Ok(ConvTranspose2d::new(ws, None, cfg))
473}