moshi_db/
conv.rs

1// Copyright (c) Kyutai, all rights reserved.
2// This source code is licensed under the license found in the
3// LICENSE file in the root directory of this source tree.
4
5use crate::streaming::{StreamMask, StreamTensor, StreamingModule};
6use candle::{IndexOp, Module, Result, Tensor, D};
7use candle_nn::{Conv1d, VarBuilder};
8
9#[allow(clippy::enum_variant_names)]
10#[derive(Debug, Copy, Clone, PartialEq, Eq)]
11pub enum Norm {
12    WeightNorm,
13    SpectralNorm,
14    TimeGroupNorm,
15}
16
17#[derive(Debug, Copy, Clone, PartialEq, Eq)]
18pub enum PadMode {
19    Constant,
20    Reflect,
21    Replicate,
22}
23
24// Applies weight norm for inference by recomputing the weight tensor. This
25// does not apply to training.
26// https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html
27fn conv1d_weight_norm(
28    in_c: usize,
29    out_c: usize,
30    kernel_size: usize,
31    bias: bool,
32    config: candle_nn::Conv1dConfig,
33    vb: VarBuilder,
34) -> Result<Conv1d> {
35    let weight = if vb.contains_tensor("weight") {
36        vb.get((out_c, in_c, kernel_size), "weight")?
37    } else {
38        let weight_g = vb.get((out_c, 1, 1), "weight_g")?;
39        let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?;
40        let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
41        weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?
42    };
43    let bias = if bias { Some(vb.get(out_c, "bias")?) } else { None };
44    Ok(Conv1d::new(weight, bias, config))
45}
46
47#[derive(Debug, Clone)]
48pub struct NormConv1d {
49    conv: Conv1d,
50    norm: Option<candle_nn::GroupNorm>,
51    span: tracing::Span,
52}
53
54impl NormConv1d {
55    #[allow(clippy::too_many_arguments)]
56    pub fn new(
57        in_c: usize,
58        out_c: usize,
59        k_size: usize,
60        causal: bool,
61        norm: Option<Norm>,
62        bias: bool,
63        cfg: candle_nn::Conv1dConfig,
64        vb: VarBuilder,
65    ) -> Result<Self> {
66        let conv = match norm {
67            None | Some(Norm::TimeGroupNorm) => {
68                if bias {
69                    candle_nn::conv1d(in_c, out_c, k_size, cfg, vb.pp("conv"))?
70                } else {
71                    candle_nn::conv1d_no_bias(in_c, out_c, k_size, cfg, vb.pp("conv"))?
72                }
73            }
74            Some(Norm::WeightNorm) => {
75                conv1d_weight_norm(in_c, out_c, k_size, bias, cfg, vb.pp("conv"))?
76            }
77            Some(Norm::SpectralNorm) => candle::bail!("SpectralNorm is not supported yet."),
78        };
79        let norm = match norm {
80            None | Some(Norm::WeightNorm) | Some(Norm::SpectralNorm) => None,
81            Some(Norm::TimeGroupNorm) => {
82                if causal {
83                    candle::bail!("GroupNorm doesn't support causal evaluation.")
84                }
85                let norm = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
86                Some(norm)
87            }
88        };
89        Ok(Self { conv, norm, span: tracing::span!(tracing::Level::TRACE, "norm-conv1d") })
90    }
91}
92
93impl Module for NormConv1d {
94    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
95        let _enter = self.span.enter();
96        let xs = xs.apply(&self.conv)?;
97        match self.norm.as_ref() {
98            None => Ok(xs),
99            Some(norm) => xs.apply(norm),
100        }
101    }
102}
103
104#[derive(Debug, Clone)]
105pub struct NormConvTranspose1d {
106    ws: Tensor,
107    bs: Option<Tensor>,
108    k_size: usize,
109    stride: usize,
110    groups: usize,
111    norm: Option<candle_nn::GroupNorm>,
112    span: tracing::Span,
113}
114
115impl NormConvTranspose1d {
116    #[allow(clippy::too_many_arguments)]
117    pub fn new(
118        in_c: usize,
119        out_c: usize,
120        k_size: usize,
121        causal: bool,
122        norm: Option<Norm>,
123        bias: bool,
124        stride: usize,
125        groups: usize,
126        vb: VarBuilder,
127    ) -> Result<Self> {
128        let vb = vb.pp("convtr");
129        let bs = if bias { Some(vb.get(out_c, "bias")?) } else { None };
130        let ws = match norm {
131            None | Some(Norm::TimeGroupNorm) => vb.get((in_c, out_c / groups, k_size), "weight")?,
132            Some(Norm::WeightNorm) => {
133                if vb.contains_tensor("weight") {
134                    vb.get((in_c, out_c, k_size), "weight")?
135                } else {
136                    let weight_g = vb.get((in_c, 1, 1), "weight_g")?;
137                    let weight_v = vb.get((in_c, out_c, k_size), "weight_v")?;
138                    let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
139                    weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?
140                }
141            }
142            Some(Norm::SpectralNorm) => candle::bail!("SpectralNorm is not supported yet."),
143        };
144        let (ws, groups) = if groups == out_c && in_c == out_c {
145            let eye = Tensor::eye(out_c, ws.dtype(), ws.device())?;
146            let ws = ws.repeat((1, out_c, 1))?.mul(&eye.unsqueeze(2)?.repeat((1, 1, k_size))?)?;
147            (ws, 1)
148        } else {
149            (ws, groups)
150        };
151        let norm = match norm {
152            None | Some(Norm::WeightNorm) | Some(Norm::SpectralNorm) => None,
153            Some(Norm::TimeGroupNorm) => {
154                if causal {
155                    candle::bail!("GroupNorm doesn't support causal evaluation.")
156                }
157                let norm = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
158                Some(norm)
159            }
160        };
161        Ok(Self {
162            ws,
163            bs,
164            k_size,
165            stride,
166            groups,
167            norm,
168            span: tracing::span!(tracing::Level::TRACE, "norm-conv-tr1d"),
169        })
170    }
171}
172
173impl Module for NormConvTranspose1d {
174    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
175        let _enter = self.span.enter();
176        // conv-transpose1d seems to be broken on metal after enough iterations. Causing
177        // the following error:
178        // _status < MTLCommandBufferStatusCommitted >
179        // -[IOGPUMetalCommandBuffer setCurrentCommandEncoder:]
180        // This is now fixed in candle.
181        let xs = Tensor::conv_transpose1d(xs, &self.ws, 0, 0, self.stride, 1, self.groups)?;
182        let xs = match &self.bs {
183            None => xs,
184            Some(bias) => {
185                let b = bias.dims1()?;
186                let bias = bias.reshape((1, b, 1))?;
187                xs.broadcast_add(&bias)?
188            }
189        };
190        match self.norm.as_ref() {
191            None => Ok(xs),
192            Some(norm) => xs.apply(norm),
193        }
194    }
195}
196
197fn get_extra_padding_for_conv1d(
198    xs: &Tensor,
199    k_size: usize,
200    stride: usize,
201    padding_total: usize,
202) -> Result<usize> {
203    let len = xs.dim(D::Minus1)?;
204    let n_frames = (len + padding_total).saturating_sub(k_size) as f64 / stride as f64 + 1.0;
205    let ideal_len =
206        ((n_frames.ceil() as usize - 1) * stride + k_size).saturating_sub(padding_total);
207    Ok(ideal_len.saturating_sub(len))
208}
209
210fn pad1d(xs: &Tensor, pad_l: usize, pad_r: usize, mode: PadMode) -> Result<Tensor> {
211    match mode {
212        PadMode::Constant => xs.pad_with_zeros(D::Minus1, pad_l, pad_r),
213        PadMode::Reflect => candle::bail!("pad-mode 'reflect' is not supported"),
214        PadMode::Replicate => xs.pad_with_same(D::Minus1, pad_l, pad_r),
215    }
216}
217
218fn unpad1d(xs: &Tensor, unpad_l: usize, unpad_r: usize) -> Result<Tensor> {
219    let len = xs.dim(D::Minus1)?;
220    if len < unpad_l + unpad_r {
221        candle::bail!("unpad1d: tensor len {len} is too low, {unpad_l} + {unpad_r}")
222    }
223    xs.narrow(D::Minus1, unpad_l, len - (unpad_l + unpad_r))
224}
225
226#[derive(Debug, Clone)]
227pub struct StreamableConv1d {
228    conv: NormConv1d,
229    causal: bool,
230    pad_mode: PadMode,
231    state_prev_xs: StreamTensor,
232    left_pad_applied: bool,
233    kernel_size: usize,
234    span: tracing::Span,
235}
236
237impl StreamableConv1d {
238    #[allow(clippy::too_many_arguments)]
239    pub fn new(
240        in_c: usize,
241        out_c: usize,
242        k_size: usize,
243        stride: usize,
244        dilation: usize,
245        groups: usize,
246        bias: bool,
247        causal: bool,
248        norm: Option<Norm>,
249        pad_mode: PadMode,
250        vb: VarBuilder,
251    ) -> Result<Self> {
252        let cfg = candle_nn::Conv1dConfig {
253            padding: 0,
254            stride,
255            dilation,
256            groups,
257            cudnn_fwd_algo: Some(candle::conv::CudnnFwdAlgo::ImplicitGemm),
258        };
259        let conv = NormConv1d::new(in_c, out_c, k_size, causal, norm, bias, cfg, vb.pp("conv"))?;
260        if k_size < stride {
261            candle::bail!("kernel-size {k_size} is smaller than stride {stride}")
262        }
263        Ok(Self {
264            conv,
265            causal,
266            pad_mode,
267            state_prev_xs: StreamTensor::empty(),
268            left_pad_applied: false,
269            kernel_size: k_size,
270            span: tracing::span!(tracing::Level::TRACE, "streamable-conv1d"),
271        })
272    }
273
274    pub fn reset_batch_idx(&mut self, batch_idx: usize, _batch_size: usize) -> Result<()> {
275        if let Some(v) = self.state_prev_xs.as_option() {
276            let v = v.contiguous()?;
277            v.i(batch_idx..(1 + batch_idx))?.zero_set()?;
278            self.state_prev_xs = StreamTensor::from_tensor(v);
279        }
280        Ok(())
281    }
282}
283
284impl Module for StreamableConv1d {
285    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
286        let _enter = self.span.enter();
287        let (_b, _t, _c) = xs.dims3()?;
288        let k_size = self.conv.conv.weight().dim(D::Minus1)?;
289        let conv_cfg = self.conv.conv.config();
290        // Effective kernel size with dilations.
291        let k_size = (k_size - 1) * conv_cfg.dilation + 1;
292        let padding_total = k_size - conv_cfg.stride;
293        let extra_padding =
294            get_extra_padding_for_conv1d(xs, k_size, conv_cfg.stride, padding_total)?;
295        let xs = if self.causal {
296            pad1d(xs, padding_total, extra_padding, self.pad_mode)?
297        } else {
298            let padding_right = padding_total / 2;
299            let padding_left = padding_total - padding_right;
300            pad1d(xs, padding_left, padding_right + extra_padding, self.pad_mode)?
301        };
302        xs.apply(&self.conv)
303    }
304}
305
impl StreamingModule for StreamableConv1d {
    /// Drops any carried-over input and re-arms the initial left padding for the next step.
    fn reset_state(&mut self) {
        self.state_prev_xs.reset();
        self.left_pad_applied = false;
    }

    /// Consumes one input chunk and returns the outputs of every convolution window that is
    /// now fully available. Input samples not yet covered by a full window are kept in
    /// `state_prev_xs` for the next call. Returns early, without touching the state, when
    /// `xs` holds no tensor.
    ///
    /// NOTE(review): the first call always left-pads by the full `padding_total`, even when
    /// `self.causal` is false. Non-causal streaming therefore does not match `forward`,
    /// which splits the padding. The optional `GroupNorm` of `NormConv1d` is not applied
    /// here either, because `self.conv.conv` is called directly.
    fn step(&mut self, xs: &StreamTensor, mask: &StreamMask) -> Result<StreamTensor> {
        let _enter = self.span.enter();
        let xs = match xs.as_option() {
            None => return Ok(().into()),
            Some(xs) => xs.clone(),
        };
        // Apply the left padding once, on the very first chunk of the stream.
        let xs = if self.left_pad_applied {
            xs
        } else {
            self.left_pad_applied = true;
            let k_size = self.conv.conv.weight().dim(D::Minus1)?;
            let conv_cfg = self.conv.conv.config();
            let k_size = (k_size - 1) * conv_cfg.dilation + 1;
            let padding_total = k_size - conv_cfg.stride;
            pad1d(&xs, padding_total, 0, self.pad_mode)?
        };
        let cfg = self.conv.conv.config();
        let stride = cfg.stride;
        let dilation = cfg.dilation;
        // Effective (dilated) kernel span.
        let kernel = (self.kernel_size - 1) * dilation + 1;
        let xs = StreamTensor::cat2(&self.state_prev_xs, &xs.into(), D::Minus1)?;
        let seq_len = xs.seq_len(D::Minus1)?;
        // Number of complete windows: floor((seq_len - kernel) / stride) + 1, or 0 when
        // seq_len < kernel.
        let num_frames = (seq_len + stride).saturating_sub(kernel) / stride;
        let (state_prev_xs, ys) = if num_frames > 0 {
            // Keep everything from the start of the first window that was not emitted.
            let offset = num_frames * stride;
            let state_prev_xs = xs.narrow(D::Minus1, offset, seq_len - offset)?;
            let in_l = (num_frames - 1) * stride + kernel;
            let xs = xs.narrow(D::Minus1, 0, in_l)?;
            // We apply the underlying conv directly rather than through forward so as
            // not to apply any padding (or the optional norm) here.
            let ys = xs.apply(&self.conv.conv)?;
            (state_prev_xs, ys)
        } else {
            (xs, StreamTensor::empty())
        };
        // With a mask (reshaped to one entry per leading/batch element), masked-in entries
        // take the new state. Masked-out entries keep the previous state, or zeros if there
        // was none. The emitted `ys` is not masked.
        let state_prev_xs = match mask.as_option() {
            None => state_prev_xs,
            Some(mask) => match (state_prev_xs.as_option(), self.state_prev_xs.as_option()) {
                (None, None) => state_prev_xs,
                (Some(state_prev_xs), None) => {
                    let z = state_prev_xs.zeros_like()?;
                    let mask = mask.reshape(((), 1, 1))?.broadcast_as(state_prev_xs.shape())?;
                    mask.where_cond(state_prev_xs, &z)?.into()
                }
                (None, Some(_)) => {
                    candle::bail!("streaming conv1d should only be used with constant steps")
                }
                (Some(prev_xs), Some(prev_prev_xs)) => {
                    if prev_xs.shape() != prev_prev_xs.shape() {
                        candle::bail!("streaming conv1d should only be used with constant steps {prev_xs:?} {prev_prev_xs:?}")
                    }
                    let mask = mask.reshape(((), 1, 1))?.broadcast_as(prev_xs.shape())?;
                    mask.where_cond(prev_xs, prev_prev_xs)?.into()
                }
            },
        };
        self.state_prev_xs = state_prev_xs;
        Ok(ys)
    }
}
372
373#[derive(Debug, Clone)]
374pub struct StreamableConvTranspose1d {
375    convtr: NormConvTranspose1d,
376    causal: bool,
377    state_prev_ys: StreamTensor,
378    kernel_size: usize,
379    span: tracing::Span,
380}
381
382impl StreamableConvTranspose1d {
383    #[allow(clippy::too_many_arguments)]
384    pub fn new(
385        in_c: usize,
386        out_c: usize,
387        k_size: usize,
388        stride: usize,
389        groups: usize,
390        bias: bool,
391        causal: bool,
392        norm: Option<Norm>,
393        vb: VarBuilder,
394    ) -> Result<Self> {
395        let convtr = NormConvTranspose1d::new(
396            in_c,
397            out_c,
398            k_size,
399            causal,
400            norm,
401            bias,
402            stride,
403            groups,
404            vb.pp("convtr"),
405        )?;
406        Ok(Self {
407            convtr,
408            causal,
409            kernel_size: k_size,
410            state_prev_ys: StreamTensor::empty(),
411            span: tracing::span!(tracing::Level::TRACE, "streamable-conv-tr1d"),
412        })
413    }
414
415    pub fn reset_batch_idx(&mut self, batch_idx: usize, _batch_size: usize) -> Result<()> {
416        if let Some(v) = self.state_prev_ys.as_option() {
417            let v = v.contiguous()?;
418            v.i(batch_idx..(1 + batch_idx))?.zero_set()?;
419            self.state_prev_ys = v.into();
420        }
421        Ok(())
422    }
423}
424
425impl Module for StreamableConvTranspose1d {
426    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
427        let _enter = self.span.enter();
428        let k_size = self.convtr.k_size;
429        let stride = self.convtr.stride;
430        let padding_total = k_size.saturating_sub(stride);
431        let xs = xs.apply(&self.convtr)?;
432        if self.causal {
433            // This corresponds to trim_right_ratio = 1.
434            unpad1d(&xs, 0, padding_total)
435        } else {
436            let padding_right = padding_total / 2;
437            let padding_left = padding_total - padding_right;
438            unpad1d(&xs, padding_left, padding_right)
439        }
440    }
441}
442
443impl StreamingModule for StreamableConvTranspose1d {
444    fn reset_state(&mut self) {
445        self.state_prev_ys.reset()
446    }
447
448    fn step(&mut self, xs: &StreamTensor, mask: &StreamMask) -> Result<StreamTensor> {
449        let _enter = self.span.enter();
450        let xs = match xs.as_option() {
451            Some(xs) => xs,
452            None => return Ok(StreamTensor::empty()),
453        };
454        let stride = self.convtr.stride;
455        // We apply the underlying convtr directly rather than through forward so as
456        // not to apply any padding here.
457        let ys = self.convtr.forward(xs)?;
458        let ot = ys.dim(D::Minus1)?;
459        let ys = match self.state_prev_ys.as_option() {
460            None => ys,
461            Some(prev_ys) => {
462                let pt = prev_ys.dim(D::Minus1)?;
463                // Remove the bias as it will be applied multiple times.
464                let prev_ys = match &self.convtr.bs {
465                    None => prev_ys.clone(),
466                    Some(bias) => {
467                        let bias = bias.reshape((1, (), 1))?;
468                        prev_ys.broadcast_sub(&bias)?
469                    }
470                };
471                let ys1 = (ys.narrow(D::Minus1, 0, pt)? + prev_ys)?;
472                let ys2 = ys.narrow(D::Minus1, pt, ot - pt)?;
473                Tensor::cat(&[ys1, ys2], D::Minus1)?
474            }
475        };
476        let invalid_steps = self.kernel_size - stride;
477        let (ys, prev_ys) = StreamTensor::from(ys).split(D::Minus1, ot - invalid_steps)?;
478        let prev_ys = match mask.as_option() {
479            None => prev_ys,
480            Some(mask) => match (prev_ys.as_option(), self.state_prev_ys.as_option()) {
481                (None, None) => prev_ys,
482                (Some(prev_ys), None) => {
483                    let z = prev_ys.zeros_like()?;
484                    let mask = mask.reshape(((), 1, 1))?.broadcast_as(prev_ys.shape())?;
485                    mask.where_cond(prev_ys, &z)?.into()
486                }
487                (None, Some(_)) => {
488                    candle::bail!("streaming conv-tr1d should only be used with constant steps")
489                }
490                (Some(prev_ys), Some(prev_prev_ys)) => {
491                    if prev_ys.shape() != prev_prev_ys.shape() {
492                        candle::bail!("streaming conv-tr1d should only be used with constant steps {prev_ys:?} {prev_prev_ys:?}")
493                    }
494                    let mask = mask.reshape(((), 1, 1))?.broadcast_as(prev_ys.shape())?;
495                    mask.where_cond(prev_ys, prev_prev_ys)?.into()
496                }
497            },
498        };
499        self.state_prev_ys = prev_ys;
500        Ok(ys)
501    }
502}
503
504#[derive(Debug, Clone)]
505pub struct ConvDownsample1d {
506    conv: StreamableConv1d,
507}
508
509impl ConvDownsample1d {
510    pub fn new(
511        stride: usize,
512        dim: usize,
513        causal: bool,
514        learnt: bool,
515        vb: VarBuilder,
516    ) -> Result<Self> {
517        if !learnt {
518            candle::bail!("only learnt=true is supported")
519        }
520        let conv = StreamableConv1d::new(
521            /* in_c */ dim,
522            /* out_c */ dim,
523            /* k_size_c */ 2 * stride,
524            /* stride */ stride,
525            /* dilation */ 1,
526            /* groups */ 1, // channel_wise = false
527            /* bias */ false,
528            /* causal */ causal,
529            /* norm */ None,
530            /* pad_mode */ PadMode::Replicate,
531            vb.pp("conv"),
532        )?;
533        Ok(Self { conv })
534    }
535
536    pub fn reset_batch_idx(&mut self, batch_idx: usize, batch_size: usize) -> Result<()> {
537        self.conv.reset_batch_idx(batch_idx, batch_size)
538    }
539}
540
541impl Module for ConvDownsample1d {
542    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
543        xs.apply(&self.conv)
544    }
545}
546
547impl StreamingModule for ConvDownsample1d {
548    fn reset_state(&mut self) {
549        self.conv.reset_state()
550    }
551
552    fn step(&mut self, xs: &StreamTensor, m: &StreamMask) -> Result<StreamTensor> {
553        self.conv.step(xs, m)
554    }
555}
556
557#[derive(Debug, Clone)]
558pub struct ConvTrUpsample1d {
559    convtr: StreamableConvTranspose1d,
560}
561
562impl ConvTrUpsample1d {
563    pub fn new(
564        stride: usize,
565        dim: usize,
566        causal: bool,
567        learnt: bool,
568        vb: VarBuilder,
569    ) -> Result<Self> {
570        if !learnt {
571            candle::bail!("only learnt=true is supported")
572        }
573        let convtr = StreamableConvTranspose1d::new(
574            dim,
575            dim,
576            /* k_size */ 2 * stride,
577            /* stride */ stride,
578            /* groups */ dim,
579            /* bias */ false,
580            /* causal */ causal,
581            /* norm */ None,
582            vb.pp("convtr"),
583        )?;
584        Ok(Self { convtr })
585    }
586
587    pub fn reset_batch_idx(&mut self, batch_idx: usize, batch_size: usize) -> Result<()> {
588        self.convtr.reset_batch_idx(batch_idx, batch_size)
589    }
590}
591
592impl Module for ConvTrUpsample1d {
593    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
594        xs.apply(&self.convtr)
595    }
596}
597
598impl StreamingModule for ConvTrUpsample1d {
599    fn reset_state(&mut self) {
600        self.convtr.reset_state()
601    }
602
603    fn step(&mut self, xs: &StreamTensor, m: &StreamMask) -> Result<StreamTensor> {
604        self.convtr.step(xs, m)
605    }
606}
607
#[cfg(test)]
mod tests {
    use super::*;
    use candle::IndexOp;

    /// Streams `xs` through `module` in `len` chunks of `step_size` time steps and checks
    /// that the concatenated streaming outputs match `expected` (the `forward` output).
    fn assert_streaming_matches<M: StreamingModule>(
        module: &mut M,
        xs: &Tensor,
        expected: &Tensor,
        step_size: usize,
        len: usize,
    ) -> Result<()> {
        let mut ys_steps = vec![];
        for idx in 0..len {
            let xs = xs.i((.., .., step_size * idx..step_size * (idx + 1)))?;
            let ys = module.step(&xs.into(), &().into())?;
            if let Some(ys) = ys.as_option() {
                ys_steps.push(ys.clone())
            }
        }
        let ys_steps = Tensor::cat(&ys_steps, D::Minus1)?;
        let diff = (expected - &ys_steps)?.abs()?.flatten_all()?.max(0)?.to_vec0::<f32>()?;
        if diff > 1e-5 {
            println!("{xs}");
            println!("{expected}");
            println!("{ys_steps}");
            candle::bail!("larger diff than expected {diff}")
        }
        Ok(())
    }

    fn run_conv1d(
        k_size: usize,
        stride: usize,
        dilation: usize,
        step_size: usize,
        len: usize,
        bias: bool,
    ) -> Result<()> {
        // TODO: We should ensure for the seed to be constant when running these tests.
        let dev = &candle::Device::Cpu;
        let vm = candle_nn::VarMap::new();
        let vb = VarBuilder::from_varmap(&vm, candle::DType::F32, dev);
        let mut conv1d = StreamableConv1d::new(
            /* in_c */ 2,
            /* out_c */ 3,
            /* k_size */ k_size,
            /* stride */ stride,
            /* dilation */ dilation,
            /* groups */ 1,
            /* bias */ bias,
            /* causal */ true,
            /* norm */ None,
            /* pad_mode */ PadMode::Constant,
            vb,
        )?;
        let xs = Tensor::randn(0f32, 1., (1, 2, step_size * len), dev)?;
        let ys = conv1d.forward(&xs)?;
        assert_streaming_matches(&mut conv1d, &xs, &ys, step_size, len)
    }

    fn run_conv_tr1d(
        k_size: usize,
        stride: usize,
        step_size: usize,
        len: usize,
        bias: bool,
    ) -> Result<()> {
        // TODO: We should ensure for the seed to be constant when running these tests.
        let dev = &candle::Device::Cpu;
        let vm = candle_nn::VarMap::new();
        let vb = VarBuilder::from_varmap(&vm, candle::DType::F32, dev);
        let mut conv_tr1d = StreamableConvTranspose1d::new(
            /* in_c */ 2, /* out_c */ 3, /* k_size */ k_size,
            /* stride */ stride, /* groups */ 1, /* bias */ bias,
            /* causal */ true, /* norm */ None, vb,
        )?;
        let xs = Tensor::randn(0f32, 1., (1, 2, step_size * len), dev)?;
        let ys = conv_tr1d.forward(&xs)?;
        assert_streaming_matches(&mut conv_tr1d, &xs, &ys, step_size, len)
    }

    #[test]
    fn conv1d() -> Result<()> {
        for step_size in [1, 2, 3] {
            for bias in [false, true] {
                run_conv1d(1, 1, 1, step_size, 5, bias)?;
                run_conv1d(2, 1, 1, step_size, 5, bias)?;
                run_conv1d(2, 2, 1, step_size, 6, bias)?;
                run_conv1d(3, 2, 1, step_size, 8, bias)?;
                run_conv1d(3, 2, 2, step_size, 8, bias)?;
            }
        }
        Ok(())
    }

    #[test]
    fn conv_tr1d() -> Result<()> {
        for step_size in [1, 2, 3] {
            for bias in [false, true] {
                run_conv_tr1d(1, 1, step_size, 5, bias)?;
                run_conv_tr1d(2, 1, step_size, 5, bias)?;
                run_conv_tr1d(3, 1, step_size, 5, bias)?;
                run_conv_tr1d(3, 2, step_size, 5, bias)?;
            }
        }
        Ok(())
    }
}