1use crate::streaming::{StreamMask, StreamTensor, StreamingModule};
6use candle::{IndexOp, Module, Result, Tensor, D};
7use candle_nn::{Conv1d, VarBuilder};
8
/// Normalization attached to a convolution layer.
///
/// `SpectralNorm` is currently rejected with an error by the constructors in this
/// module.
#[allow(clippy::enum_variant_names)]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum Norm {
    /// Weight normalization: the kernel is read either directly as `weight` or
    /// reconstructed from the `weight_g` / `weight_v` pair at load time.
    WeightNorm,
    /// Spectral normalization (not supported yet).
    SpectralNorm,
    /// A single-group `GroupNorm` applied to the convolution output; it cannot be
    /// combined with causal evaluation.
    TimeGroupNorm,
}
16
/// How the time (last) dimension is padded before a streamable convolution.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum PadMode {
    /// Pad with zeros.
    Constant,
    /// Reflection padding; currently rejected with an error by `pad1d`.
    Reflect,
    /// Pad by repeating the edge values (`Tensor::pad_with_same`).
    Replicate,
}
23
24fn conv1d_weight_norm(
28 in_c: usize,
29 out_c: usize,
30 kernel_size: usize,
31 bias: bool,
32 config: candle_nn::Conv1dConfig,
33 vb: VarBuilder,
34) -> Result<Conv1d> {
35 let weight = if vb.contains_tensor("weight") {
36 vb.get((out_c, in_c, kernel_size), "weight")?
37 } else {
38 let weight_g = vb.get((out_c, 1, 1), "weight_g")?;
39 let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?;
40 let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
41 weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?
42 };
43 let bias = if bias { Some(vb.get(out_c, "bias")?) } else { None };
44 Ok(Conv1d::new(weight, bias, config))
45}
46
47#[derive(Debug, Clone)]
48pub struct NormConv1d {
49 conv: Conv1d,
50 norm: Option<candle_nn::GroupNorm>,
51 span: tracing::Span,
52}
53
54impl NormConv1d {
55 #[allow(clippy::too_many_arguments)]
56 pub fn new(
57 in_c: usize,
58 out_c: usize,
59 k_size: usize,
60 causal: bool,
61 norm: Option<Norm>,
62 bias: bool,
63 cfg: candle_nn::Conv1dConfig,
64 vb: VarBuilder,
65 ) -> Result<Self> {
66 let conv = match norm {
67 None | Some(Norm::TimeGroupNorm) => {
68 if bias {
69 candle_nn::conv1d(in_c, out_c, k_size, cfg, vb.pp("conv"))?
70 } else {
71 candle_nn::conv1d_no_bias(in_c, out_c, k_size, cfg, vb.pp("conv"))?
72 }
73 }
74 Some(Norm::WeightNorm) => {
75 conv1d_weight_norm(in_c, out_c, k_size, bias, cfg, vb.pp("conv"))?
76 }
77 Some(Norm::SpectralNorm) => candle::bail!("SpectralNorm is not supported yet."),
78 };
79 let norm = match norm {
80 None | Some(Norm::WeightNorm) | Some(Norm::SpectralNorm) => None,
81 Some(Norm::TimeGroupNorm) => {
82 if causal {
83 candle::bail!("GroupNorm doesn't support causal evaluation.")
84 }
85 let norm = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
86 Some(norm)
87 }
88 };
89 Ok(Self { conv, norm, span: tracing::span!(tracing::Level::TRACE, "norm-conv1d") })
90 }
91}
92
93impl Module for NormConv1d {
94 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
95 let _enter = self.span.enter();
96 let xs = xs.apply(&self.conv)?;
97 match self.norm.as_ref() {
98 None => Ok(xs),
99 Some(norm) => xs.apply(norm),
100 }
101 }
102}
103
104#[derive(Debug, Clone)]
105pub struct NormConvTranspose1d {
106 ws: Tensor,
107 bs: Option<Tensor>,
108 k_size: usize,
109 stride: usize,
110 groups: usize,
111 norm: Option<candle_nn::GroupNorm>,
112 span: tracing::Span,
113}
114
115impl NormConvTranspose1d {
116 #[allow(clippy::too_many_arguments)]
117 pub fn new(
118 in_c: usize,
119 out_c: usize,
120 k_size: usize,
121 causal: bool,
122 norm: Option<Norm>,
123 bias: bool,
124 stride: usize,
125 groups: usize,
126 vb: VarBuilder,
127 ) -> Result<Self> {
128 let vb = vb.pp("convtr");
129 let bs = if bias { Some(vb.get(out_c, "bias")?) } else { None };
130 let ws = match norm {
131 None | Some(Norm::TimeGroupNorm) => vb.get((in_c, out_c / groups, k_size), "weight")?,
132 Some(Norm::WeightNorm) => {
133 if vb.contains_tensor("weight") {
134 vb.get((in_c, out_c, k_size), "weight")?
135 } else {
136 let weight_g = vb.get((in_c, 1, 1), "weight_g")?;
137 let weight_v = vb.get((in_c, out_c, k_size), "weight_v")?;
138 let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
139 weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?
140 }
141 }
142 Some(Norm::SpectralNorm) => candle::bail!("SpectralNorm is not supported yet."),
143 };
144 let (ws, groups) = if groups == out_c && in_c == out_c {
145 let eye = Tensor::eye(out_c, ws.dtype(), ws.device())?;
146 let ws = ws.repeat((1, out_c, 1))?.mul(&eye.unsqueeze(2)?.repeat((1, 1, k_size))?)?;
147 (ws, 1)
148 } else {
149 (ws, groups)
150 };
151 let norm = match norm {
152 None | Some(Norm::WeightNorm) | Some(Norm::SpectralNorm) => None,
153 Some(Norm::TimeGroupNorm) => {
154 if causal {
155 candle::bail!("GroupNorm doesn't support causal evaluation.")
156 }
157 let norm = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
158 Some(norm)
159 }
160 };
161 Ok(Self {
162 ws,
163 bs,
164 k_size,
165 stride,
166 groups,
167 norm,
168 span: tracing::span!(tracing::Level::TRACE, "norm-conv-tr1d"),
169 })
170 }
171}
172
173impl Module for NormConvTranspose1d {
174 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
175 let _enter = self.span.enter();
176 let xs = Tensor::conv_transpose1d(xs, &self.ws, 0, 0, self.stride, 1, self.groups)?;
182 let xs = match &self.bs {
183 None => xs,
184 Some(bias) => {
185 let b = bias.dims1()?;
186 let bias = bias.reshape((1, b, 1))?;
187 xs.broadcast_add(&bias)?
188 }
189 };
190 match self.norm.as_ref() {
191 None => Ok(xs),
192 Some(norm) => xs.apply(norm),
193 }
194 }
195}
196
197fn get_extra_padding_for_conv1d(
198 xs: &Tensor,
199 k_size: usize,
200 stride: usize,
201 padding_total: usize,
202) -> Result<usize> {
203 let len = xs.dim(D::Minus1)?;
204 let n_frames = (len + padding_total).saturating_sub(k_size) as f64 / stride as f64 + 1.0;
205 let ideal_len =
206 ((n_frames.ceil() as usize - 1) * stride + k_size).saturating_sub(padding_total);
207 Ok(ideal_len.saturating_sub(len))
208}
209
210fn pad1d(xs: &Tensor, pad_l: usize, pad_r: usize, mode: PadMode) -> Result<Tensor> {
211 match mode {
212 PadMode::Constant => xs.pad_with_zeros(D::Minus1, pad_l, pad_r),
213 PadMode::Reflect => candle::bail!("pad-mode 'reflect' is not supported"),
214 PadMode::Replicate => xs.pad_with_same(D::Minus1, pad_l, pad_r),
215 }
216}
217
218fn unpad1d(xs: &Tensor, unpad_l: usize, unpad_r: usize) -> Result<Tensor> {
219 let len = xs.dim(D::Minus1)?;
220 if len < unpad_l + unpad_r {
221 candle::bail!("unpad1d: tensor len {len} is too low, {unpad_l} + {unpad_r}")
222 }
223 xs.narrow(D::Minus1, unpad_l, len - (unpad_l + unpad_r))
224}
225
226#[derive(Debug, Clone)]
227pub struct StreamableConv1d {
228 conv: NormConv1d,
229 causal: bool,
230 pad_mode: PadMode,
231 state_prev_xs: StreamTensor,
232 left_pad_applied: bool,
233 kernel_size: usize,
234 span: tracing::Span,
235}
236
237impl StreamableConv1d {
238 #[allow(clippy::too_many_arguments)]
239 pub fn new(
240 in_c: usize,
241 out_c: usize,
242 k_size: usize,
243 stride: usize,
244 dilation: usize,
245 groups: usize,
246 bias: bool,
247 causal: bool,
248 norm: Option<Norm>,
249 pad_mode: PadMode,
250 vb: VarBuilder,
251 ) -> Result<Self> {
252 let cfg = candle_nn::Conv1dConfig {
253 padding: 0,
254 stride,
255 dilation,
256 groups,
257 cudnn_fwd_algo: Some(candle::conv::CudnnFwdAlgo::ImplicitGemm),
258 };
259 let conv = NormConv1d::new(in_c, out_c, k_size, causal, norm, bias, cfg, vb.pp("conv"))?;
260 if k_size < stride {
261 candle::bail!("kernel-size {k_size} is smaller than stride {stride}")
262 }
263 Ok(Self {
264 conv,
265 causal,
266 pad_mode,
267 state_prev_xs: StreamTensor::empty(),
268 left_pad_applied: false,
269 kernel_size: k_size,
270 span: tracing::span!(tracing::Level::TRACE, "streamable-conv1d"),
271 })
272 }
273
274 pub fn reset_batch_idx(&mut self, batch_idx: usize, _batch_size: usize) -> Result<()> {
275 if let Some(v) = self.state_prev_xs.as_option() {
276 let v = v.contiguous()?;
277 v.i(batch_idx..(1 + batch_idx))?.zero_set()?;
278 self.state_prev_xs = StreamTensor::from_tensor(v);
279 }
280 Ok(())
281 }
282}
283
284impl Module for StreamableConv1d {
285 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
286 let _enter = self.span.enter();
287 let (_b, _t, _c) = xs.dims3()?;
288 let k_size = self.conv.conv.weight().dim(D::Minus1)?;
289 let conv_cfg = self.conv.conv.config();
290 let k_size = (k_size - 1) * conv_cfg.dilation + 1;
292 let padding_total = k_size - conv_cfg.stride;
293 let extra_padding =
294 get_extra_padding_for_conv1d(xs, k_size, conv_cfg.stride, padding_total)?;
295 let xs = if self.causal {
296 pad1d(xs, padding_total, extra_padding, self.pad_mode)?
297 } else {
298 let padding_right = padding_total / 2;
299 let padding_left = padding_total - padding_right;
300 pad1d(xs, padding_left, padding_right + extra_padding, self.pad_mode)?
301 };
302 xs.apply(&self.conv)
303 }
304}
305
impl StreamingModule for StreamableConv1d {
    /// Clears the buffered input and re-enables the one-time left padding, so the next
    /// non-empty `step` behaves like the first one.
    fn reset_state(&mut self) {
        self.state_prev_xs.reset();
        self.left_pad_applied = false;
    }

    /// Consumes one chunk of input (last dimension). Returns every output frame whose
    /// window is now complete, or an empty stream tensor when there is none.
    ///
    /// - On the first non-empty call after construction/reset, the chunk is left-padded
    ///   with `padding_total = effective_kernel - stride` samples using `pad_mode`. This
    ///   matches the left padding `forward` applies when `causal` is true. No right or
    ///   extra padding is ever added, so an incomplete trailing window is not emitted.
    /// - Input samples not yet consumed by a complete window are kept in
    ///   `state_prev_xs` and prepended to the next chunk.
    /// - The inner `Conv1d` is applied directly, so the optional norm held by
    ///   `NormConv1d` is not applied here.
    /// - With a `mask`, batch elements whose mask is false keep their previous state
    ///   (zeros if there was none yet). This requires the state to keep the same shape
    ///   across steps, otherwise an error is returned.
    fn step(&mut self, xs: &StreamTensor, mask: &StreamMask) -> Result<StreamTensor> {
        let _enter = self.span.enter();
        let xs = match xs.as_option() {
            None => return Ok(().into()),
            Some(xs) => xs.clone(),
        };
        let xs = if self.left_pad_applied {
            xs
        } else {
            // Note: the flag is set before padding, even if padding then fails.
            self.left_pad_applied = true;
            let k_size = self.conv.conv.weight().dim(D::Minus1)?;
            let conv_cfg = self.conv.conv.config();
            // Effective kernel size with dilation, computed as in `forward`.
            let k_size = (k_size - 1) * conv_cfg.dilation + 1;
            let padding_total = k_size - conv_cfg.stride;
            pad1d(&xs, padding_total, 0, self.pad_mode)?
        };
        let cfg = self.conv.conv.config();
        let stride = cfg.stride;
        let dilation = cfg.dilation;
        let kernel = (self.kernel_size - 1) * dilation + 1;
        let xs = StreamTensor::cat2(&self.state_prev_xs, &xs.into(), D::Minus1)?;
        let seq_len = xs.seq_len(D::Minus1)?;
        // Complete windows available: (seq_len - kernel) / stride + 1, or 0 when
        // seq_len < kernel.
        let num_frames = (seq_len + stride).saturating_sub(kernel) / stride;
        let (state_prev_xs, ys) = if num_frames > 0 {
            // Samples from `offset` on start the next, not yet complete, window.
            let offset = num_frames * stride;
            let state_prev_xs = xs.narrow(D::Minus1, offset, seq_len - offset)?;
            // Number of input samples spanned by the `num_frames` windows.
            let in_l = (num_frames - 1) * stride + kernel;
            let xs = xs.narrow(D::Minus1, 0, in_l)?;
            let ys = xs.apply(&self.conv.conv)?;
            (state_prev_xs, ys)
        } else {
            // Not enough samples for a single window: buffer everything.
            (xs, StreamTensor::empty())
        };
        let state_prev_xs = match mask.as_option() {
            None => state_prev_xs,
            Some(mask) => match (state_prev_xs.as_option(), self.state_prev_xs.as_option()) {
                (None, None) => state_prev_xs,
                // No previous state: masked-out batch elements get a zero state.
                (Some(state_prev_xs), None) => {
                    let z = state_prev_xs.zeros_like()?;
                    let mask = mask.reshape(((), 1, 1))?.broadcast_as(state_prev_xs.shape())?;
                    mask.where_cond(state_prev_xs, &z)?.into()
                }
                (None, Some(_)) => {
                    candle::bail!("streaming conv1d should only be used with constant steps")
                }
                // Masked-out batch elements keep their previous state.
                (Some(prev_xs), Some(prev_prev_xs)) => {
                    if prev_xs.shape() != prev_prev_xs.shape() {
                        candle::bail!("streaming conv1d should only be used with constant steps {prev_xs:?} {prev_prev_xs:?}")
                    }
                    let mask = mask.reshape(((), 1, 1))?.broadcast_as(prev_xs.shape())?;
                    mask.where_cond(prev_xs, prev_prev_xs)?.into()
                }
            },
        };
        self.state_prev_xs = state_prev_xs;
        Ok(ys)
    }
}
372
373#[derive(Debug, Clone)]
374pub struct StreamableConvTranspose1d {
375 convtr: NormConvTranspose1d,
376 causal: bool,
377 state_prev_ys: StreamTensor,
378 kernel_size: usize,
379 span: tracing::Span,
380}
381
382impl StreamableConvTranspose1d {
383 #[allow(clippy::too_many_arguments)]
384 pub fn new(
385 in_c: usize,
386 out_c: usize,
387 k_size: usize,
388 stride: usize,
389 groups: usize,
390 bias: bool,
391 causal: bool,
392 norm: Option<Norm>,
393 vb: VarBuilder,
394 ) -> Result<Self> {
395 let convtr = NormConvTranspose1d::new(
396 in_c,
397 out_c,
398 k_size,
399 causal,
400 norm,
401 bias,
402 stride,
403 groups,
404 vb.pp("convtr"),
405 )?;
406 Ok(Self {
407 convtr,
408 causal,
409 kernel_size: k_size,
410 state_prev_ys: StreamTensor::empty(),
411 span: tracing::span!(tracing::Level::TRACE, "streamable-conv-tr1d"),
412 })
413 }
414
415 pub fn reset_batch_idx(&mut self, batch_idx: usize, _batch_size: usize) -> Result<()> {
416 if let Some(v) = self.state_prev_ys.as_option() {
417 let v = v.contiguous()?;
418 v.i(batch_idx..(1 + batch_idx))?.zero_set()?;
419 self.state_prev_ys = v.into();
420 }
421 Ok(())
422 }
423}
424
425impl Module for StreamableConvTranspose1d {
426 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
427 let _enter = self.span.enter();
428 let k_size = self.convtr.k_size;
429 let stride = self.convtr.stride;
430 let padding_total = k_size.saturating_sub(stride);
431 let xs = xs.apply(&self.convtr)?;
432 if self.causal {
433 unpad1d(&xs, 0, padding_total)
435 } else {
436 let padding_right = padding_total / 2;
437 let padding_left = padding_total - padding_right;
438 unpad1d(&xs, padding_left, padding_right)
439 }
440 }
441}
442
443impl StreamingModule for StreamableConvTranspose1d {
444 fn reset_state(&mut self) {
445 self.state_prev_ys.reset()
446 }
447
448 fn step(&mut self, xs: &StreamTensor, mask: &StreamMask) -> Result<StreamTensor> {
449 let _enter = self.span.enter();
450 let xs = match xs.as_option() {
451 Some(xs) => xs,
452 None => return Ok(StreamTensor::empty()),
453 };
454 let stride = self.convtr.stride;
455 let ys = self.convtr.forward(xs)?;
458 let ot = ys.dim(D::Minus1)?;
459 let ys = match self.state_prev_ys.as_option() {
460 None => ys,
461 Some(prev_ys) => {
462 let pt = prev_ys.dim(D::Minus1)?;
463 let prev_ys = match &self.convtr.bs {
465 None => prev_ys.clone(),
466 Some(bias) => {
467 let bias = bias.reshape((1, (), 1))?;
468 prev_ys.broadcast_sub(&bias)?
469 }
470 };
471 let ys1 = (ys.narrow(D::Minus1, 0, pt)? + prev_ys)?;
472 let ys2 = ys.narrow(D::Minus1, pt, ot - pt)?;
473 Tensor::cat(&[ys1, ys2], D::Minus1)?
474 }
475 };
476 let invalid_steps = self.kernel_size - stride;
477 let (ys, prev_ys) = StreamTensor::from(ys).split(D::Minus1, ot - invalid_steps)?;
478 let prev_ys = match mask.as_option() {
479 None => prev_ys,
480 Some(mask) => match (prev_ys.as_option(), self.state_prev_ys.as_option()) {
481 (None, None) => prev_ys,
482 (Some(prev_ys), None) => {
483 let z = prev_ys.zeros_like()?;
484 let mask = mask.reshape(((), 1, 1))?.broadcast_as(prev_ys.shape())?;
485 mask.where_cond(prev_ys, &z)?.into()
486 }
487 (None, Some(_)) => {
488 candle::bail!("streaming conv-tr1d should only be used with constant steps")
489 }
490 (Some(prev_ys), Some(prev_prev_ys)) => {
491 if prev_ys.shape() != prev_prev_ys.shape() {
492 candle::bail!("streaming conv-tr1d should only be used with constant steps {prev_ys:?} {prev_prev_ys:?}")
493 }
494 let mask = mask.reshape(((), 1, 1))?.broadcast_as(prev_ys.shape())?;
495 mask.where_cond(prev_ys, prev_prev_ys)?.into()
496 }
497 },
498 };
499 self.state_prev_ys = prev_ys;
500 Ok(ys)
501 }
502}
503
504#[derive(Debug, Clone)]
505pub struct ConvDownsample1d {
506 conv: StreamableConv1d,
507}
508
509impl ConvDownsample1d {
510 pub fn new(
511 stride: usize,
512 dim: usize,
513 causal: bool,
514 learnt: bool,
515 vb: VarBuilder,
516 ) -> Result<Self> {
517 if !learnt {
518 candle::bail!("only learnt=true is supported")
519 }
520 let conv = StreamableConv1d::new(
521 dim,
522 dim,
523 2 * stride,
524 stride,
525 1,
526 1, false,
528 causal,
529 None,
530 PadMode::Replicate,
531 vb.pp("conv"),
532 )?;
533 Ok(Self { conv })
534 }
535
536 pub fn reset_batch_idx(&mut self, batch_idx: usize, batch_size: usize) -> Result<()> {
537 self.conv.reset_batch_idx(batch_idx, batch_size)
538 }
539}
540
541impl Module for ConvDownsample1d {
542 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
543 xs.apply(&self.conv)
544 }
545}
546
547impl StreamingModule for ConvDownsample1d {
548 fn reset_state(&mut self) {
549 self.conv.reset_state()
550 }
551
552 fn step(&mut self, xs: &StreamTensor, m: &StreamMask) -> Result<StreamTensor> {
553 self.conv.step(xs, m)
554 }
555}
556
557#[derive(Debug, Clone)]
558pub struct ConvTrUpsample1d {
559 convtr: StreamableConvTranspose1d,
560}
561
562impl ConvTrUpsample1d {
563 pub fn new(
564 stride: usize,
565 dim: usize,
566 causal: bool,
567 learnt: bool,
568 vb: VarBuilder,
569 ) -> Result<Self> {
570 if !learnt {
571 candle::bail!("only learnt=true is supported")
572 }
573 let convtr = StreamableConvTranspose1d::new(
574 dim,
575 dim,
576 2 * stride,
577 stride,
578 dim,
579 false,
580 causal,
581 None,
582 vb.pp("convtr"),
583 )?;
584 Ok(Self { convtr })
585 }
586
587 pub fn reset_batch_idx(&mut self, batch_idx: usize, batch_size: usize) -> Result<()> {
588 self.convtr.reset_batch_idx(batch_idx, batch_size)
589 }
590}
591
592impl Module for ConvTrUpsample1d {
593 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
594 xs.apply(&self.convtr)
595 }
596}
597
598impl StreamingModule for ConvTrUpsample1d {
599 fn reset_state(&mut self) {
600 self.convtr.reset_state()
601 }
602
603 fn step(&mut self, xs: &StreamTensor, m: &StreamMask) -> Result<StreamTensor> {
604 self.convtr.step(xs, m)
605 }
606}
607
#[cfg(test)]
mod tests {
    use super::*;
    use candle::IndexOp;

    /// Feeds `xs` to `module` in `len` chunks of `step_size` samples (last dimension)
    /// and collects every non-empty output.
    fn stream_in_chunks<M: StreamingModule>(
        module: &mut M,
        xs: &Tensor,
        step_size: usize,
        len: usize,
    ) -> Result<Vec<Tensor>> {
        let mut outputs = Vec::with_capacity(len);
        for idx in 0..len {
            let start = step_size * idx;
            let chunk = xs.i((.., .., start..start + step_size))?;
            let ys = module.step(&chunk.into(), &().into())?;
            if let Some(ys) = ys.as_option() {
                outputs.push(ys.clone())
            }
        }
        Ok(outputs)
    }

    /// Fails when `full` and the concatenation of `steps` differ by more than 1e-5.
    fn check_streaming_matches(xs: &Tensor, full: &Tensor, steps: &[Tensor]) -> Result<()> {
        let streamed = Tensor::cat(steps, D::Minus1)?;
        let diff = (full - &streamed)?.abs()?.flatten_all()?.max(0)?.to_vec0::<f32>()?;
        if diff > 1e-5 {
            println!("{xs}");
            println!("{full}");
            println!("{streamed}");
            candle::bail!("larger diff than expected {diff}")
        }
        Ok(())
    }

    fn run_conv1d(
        k_size: usize,
        stride: usize,
        dilation: usize,
        step_size: usize,
        len: usize,
        bias: bool,
    ) -> Result<()> {
        let dev = &candle::Device::Cpu;
        let vm = candle_nn::VarMap::new();
        let vb = VarBuilder::from_varmap(&vm, candle::DType::F32, dev);
        let mut conv1d = StreamableConv1d::new(
            2,
            3,
            k_size,
            stride,
            dilation,
            1,
            bias,
            true,
            None,
            PadMode::Constant,
            vb,
        )?;
        let xs = Tensor::randn(0f32, 1., (1, 2, step_size * len), dev)?;
        let ys = conv1d.forward(&xs)?;
        let ys_steps = stream_in_chunks(&mut conv1d, &xs, step_size, len)?;
        check_streaming_matches(&xs, &ys, &ys_steps)
    }

    fn run_conv_tr1d(
        k_size: usize,
        stride: usize,
        step_size: usize,
        len: usize,
        bias: bool,
    ) -> Result<()> {
        let dev = &candle::Device::Cpu;
        let vm = candle_nn::VarMap::new();
        let vb = VarBuilder::from_varmap(&vm, candle::DType::F32, dev);
        let mut conv_tr1d =
            StreamableConvTranspose1d::new(2, 3, k_size, stride, 1, bias, true, None, vb)?;
        let xs = Tensor::randn(0f32, 1., (1, 2, step_size * len), dev)?;
        let ys = conv_tr1d.forward(&xs)?;
        let ys_steps = stream_in_chunks(&mut conv_tr1d, &xs, step_size, len)?;
        check_streaming_matches(&xs, &ys, &ys_steps)
    }

    #[test]
    fn conv1d() -> Result<()> {
        for step_size in [1, 2, 3] {
            for bias in [false, true] {
                run_conv1d(1, 1, 1, step_size, 5, bias)?;
                run_conv1d(2, 1, 1, step_size, 5, bias)?;
                run_conv1d(2, 2, 1, step_size, 6, bias)?;
                run_conv1d(3, 2, 1, step_size, 8, bias)?;
                run_conv1d(3, 2, 2, step_size, 8, bias)?;
            }
        }
        Ok(())
    }

    #[test]
    fn conv_tr1d() -> Result<()> {
        for step_size in [1, 2, 3] {
            for bias in [false, true] {
                run_conv_tr1d(1, 1, step_size, 5, bias)?;
                run_conv_tr1d(2, 1, step_size, 5, bias)?;
                run_conv_tr1d(3, 1, step_size, 5, bias)?;
                run_conv_tr1d(3, 2, step_size, 5, bias)?;
            }
        }
        Ok(())
    }
}