moshi_db/
seanet.rs

// Copyright (c) Kyutai, all rights reserved.
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

use crate::streaming::{self, StreamMask, StreamTensor, StreamingModule};
use candle::{Module, Result, Tensor};
use candle_nn::VarBuilder;

use crate::conv::{StreamableConv1d, StreamableConvTranspose1d};

/// Hyper-parameters shared by the SEANet encoder and decoder.
#[derive(Debug, Clone)]
pub struct Config {
    /// Size of the latent representation produced by the encoder and
    /// consumed by the decoder.
    pub dimension: usize,
    /// Number of audio channels (e.g. 1 for mono).
    pub channels: usize,
    /// Whether the outer convolutions use causal padding.
    pub causal: bool,
    /// Base number of convolution filters; doubled by the encoder at each
    /// downsampling stage (halved by the decoder at each upsampling stage).
    pub n_filters: usize,
    /// Number of residual blocks per down/upsampling stage.
    pub n_residual_layers: usize,
    /// Resampling ratios, one per stage. The encoder iterates them in
    /// reverse order, the decoder in declared order.
    pub ratios: Vec<usize>,
    /// Activation applied before each convolution.
    pub activation: candle_nn::Activation,
    /// Normalization applied to convolutions (unless disabled for outer blocks).
    pub norm: crate::conv::Norm,
    /// Kernel size of the initial convolution.
    pub kernel_size: usize,
    /// Kernel size of the first convolution inside each residual block.
    pub residual_kernel_size: usize,
    /// Kernel size of the final convolution.
    pub last_kernel_size: usize,
    /// Base of the exponentially growing dilation across residual blocks
    /// (block `j` uses dilation `dilation_base^j`).
    pub dilation_base: usize,
    /// Padding mode for the streamable convolutions.
    pub pad_mode: crate::conv::PadMode,
    /// When true, residual blocks use an identity skip connection instead of
    /// a 1x1 convolution.
    pub true_skip: bool,
    /// Hidden channels inside a residual block are `dim / compress`.
    pub compress: usize,
    /// Number of LSTM layers; this implementation only supports 0 and bails
    /// otherwise.
    pub lstm: usize,
    /// Disables normalization for the first/last `n` outer blocks
    /// (init conv, one block per ratio, final conv).
    pub disable_norm_outer_blocks: usize,
    /// Optional activation applied to the decoder output.
    pub final_activation: Option<candle_nn::Activation>,
}
32
/// A SEANet residual block: a chain of (activation, conv) pairs whose output
/// is added to a skip path (identity or 1x1 convolution).
#[derive(Debug, Clone)]
pub struct SeaNetResnetBlock {
    // Convolutions of the main path; the activation is applied before each
    // one in forward/step.
    block: Vec<StreamableConv1d>,
    // `None` means an identity skip connection (`true_skip` in the config).
    shortcut: Option<StreamableConv1d>,
    activation: candle_nn::Activation,
    // Streaming-aware addition that merges the main and skip paths.
    skip_op: streaming::StreamingBinOp,
    span: tracing::Span,
}
41
42impl SeaNetResnetBlock {
43    #[allow(clippy::too_many_arguments)]
44    pub fn new(
45        dim: usize,
46        k_sizes_and_dilations: &[(usize, usize)],
47        activation: candle_nn::Activation,
48        norm: Option<crate::conv::Norm>,
49        causal: bool,
50        pad_mode: crate::conv::PadMode,
51        compress: usize,
52        true_skip: bool,
53        vb: VarBuilder,
54    ) -> Result<Self> {
55        let mut block = Vec::with_capacity(k_sizes_and_dilations.len());
56        let hidden = dim / compress;
57        let vb_b = vb.pp("block");
58        for (i, (k_size, dilation)) in k_sizes_and_dilations.iter().enumerate() {
59            let in_c = if i == 0 { dim } else { hidden };
60            let out_c = if i == k_sizes_and_dilations.len() - 1 { dim } else { hidden };
61            let c = StreamableConv1d::new(
62                in_c,
63                out_c,
64                /* k_size */ *k_size,
65                /* stride */ 1,
66                /* dilation */ *dilation,
67                /* groups */ 1,
68                /* bias */ true,
69                /* causal */ causal,
70                /* norm */ norm,
71                /* pad_mode */ pad_mode,
72                vb_b.pp(2 * i + 1),
73            )?;
74            block.push(c)
75        }
76        let shortcut = if true_skip {
77            None
78        } else {
79            let c = StreamableConv1d::new(
80                dim,
81                dim,
82                /* k_size */ 1,
83                /* stride */ 1,
84                /* dilation */ 1,
85                /* groups */ 1,
86                /* bias */ true,
87                /* causal */ causal,
88                /* norm */ norm,
89                /* pad_mode */ pad_mode,
90                vb.pp("shortcut"),
91            )?;
92            Some(c)
93        };
94        Ok(Self {
95            block,
96            shortcut,
97            activation,
98            skip_op: streaming::StreamingBinOp::new(streaming::BinOp::Add, candle::D::Minus1),
99            span: tracing::span!(tracing::Level::TRACE, "sea-resnet"),
100        })
101    }
102
103    pub fn reset_batch_idx(&mut self, batch_idx: usize, batch_size: usize) -> Result<()> {
104        for b in self.block.iter_mut() {
105            b.reset_batch_idx(batch_idx, batch_size)?;
106        }
107        if let Some(shortcut) = self.shortcut.as_mut() {
108            shortcut.reset_batch_idx(batch_idx, batch_size)?;
109        }
110        self.skip_op.reset_batch_idx(batch_idx, batch_size)?;
111        Ok(())
112    }
113}
114
115impl Module for SeaNetResnetBlock {
116    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
117        let _enter = self.span.enter();
118        let mut ys = xs.clone();
119        for block in self.block.iter() {
120            ys = ys.apply(&self.activation)?.apply(block)?;
121        }
122        match self.shortcut.as_ref() {
123            None => ys + xs,
124            Some(shortcut) => ys + xs.apply(shortcut),
125        }
126    }
127}
128
129impl StreamingModule for SeaNetResnetBlock {
130    fn reset_state(&mut self) {
131        self.skip_op.reset_state();
132        for block in self.block.iter_mut() {
133            block.reset_state()
134        }
135        if let Some(shortcut) = self.shortcut.as_mut() {
136            shortcut.reset_state()
137        }
138    }
139
140    fn step(&mut self, xs: &StreamTensor, m: &StreamMask) -> Result<StreamTensor> {
141        let _enter = self.span.enter();
142        let mut ys = xs.clone();
143        for block in self.block.iter_mut() {
144            ys = block.step(&ys.apply(&self.activation)?, m)?;
145        }
146        match self.shortcut.as_mut() {
147            None => self.skip_op.step(&ys, xs, m),
148            Some(shortcut) => self.skip_op.step(&ys, &shortcut.step(xs, m)?, m),
149        }
150    }
151}
152
/// One encoder stage: residual blocks followed by a strided downsampling conv.
#[derive(Debug, Clone)]
struct EncoderLayer {
    // Residual blocks applied before downsampling.
    residuals: Vec<SeaNetResnetBlock>,
    // Strided convolution (stride = stage ratio) that reduces the temporal
    // resolution and doubles the channel count.
    downsample: StreamableConv1d,
}
158
/// SEANet encoder: initial convolution, a stack of residual + downsample
/// stages (one per ratio), and a final projection to `Config::dimension`.
#[derive(Debug, Clone)]
pub struct SeaNetEncoder {
    init_conv1d: StreamableConv1d,
    // Applied before every downsample and before the final convolution.
    activation: candle_nn::Activation,
    layers: Vec<EncoderLayer>,
    final_conv1d: StreamableConv1d,
    span: tracing::Span,
}
167
impl SeaNetEncoder {
    /// Builds the encoder from `cfg`, loading weights from `vb.pp("model")`.
    ///
    /// `layer_idx` tracks the sub-module index under `model`. The indices are
    /// not contiguous: each downsample conv is read at `layer_idx + 1` and the
    /// counter then advances by 2, leaving gaps (presumably activation slots
    /// in the reference sequential layout — confirm against the checkpoint).
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        if cfg.lstm > 0 {
            candle::bail!("seanet lstm is not supported")
        }
        // Outer blocks: init conv + one per ratio + final conv.
        let n_blocks = 2 + cfg.ratios.len();
        // Channel multiplier; doubled after each downsampling stage.
        let mut mult = 1usize;
        let init_norm = if cfg.disable_norm_outer_blocks >= 1 { None } else { Some(cfg.norm) };
        let mut layer_idx = 0;
        let vb = vb.pp("model");
        let init_conv1d = StreamableConv1d::new(
            cfg.channels,
            mult * cfg.n_filters,
            cfg.kernel_size,
            /* stride */ 1,
            /* dilation */ 1,
            /* groups */ 1,
            /* bias */ true,
            /* causal */ cfg.causal,
            /* norm */ init_norm,
            /* pad_mode */ cfg.pad_mode,
            vb.pp(layer_idx),
        )?;
        layer_idx += 1;
        let mut layers = Vec::with_capacity(cfg.ratios.len());

        // The encoder consumes the ratios in reverse order; the decoder uses
        // them in declared order.
        for (i, &ratio) in cfg.ratios.iter().rev().enumerate() {
            let norm = if cfg.disable_norm_outer_blocks >= i + 2 { None } else { Some(cfg.norm) };
            let mut residuals = Vec::with_capacity(cfg.n_residual_layers);
            for j in 0..cfg.n_residual_layers {
                let resnet_block = SeaNetResnetBlock::new(
                    mult * cfg.n_filters,
                    // Dilation grows exponentially with the block index `j`;
                    // the trailing (1, 1) entry is the non-dilated 1x1 conv.
                    &[(cfg.residual_kernel_size, cfg.dilation_base.pow(j as u32)), (1, 1)],
                    cfg.activation,
                    norm,
                    cfg.causal,
                    cfg.pad_mode,
                    cfg.compress,
                    cfg.true_skip,
                    vb.pp(layer_idx),
                )?;
                residuals.push(resnet_block);
                layer_idx += 1;
            }
            let downsample = StreamableConv1d::new(
                mult * cfg.n_filters,
                mult * cfg.n_filters * 2,
                /* k_size */ ratio * 2,
                /* stride */ ratio,
                /* dilation */ 1,
                /* groups */ 1,
                /* bias */ true,
                // NOTE(review): hardcoded `true` rather than `cfg.causal` —
                // looks intentional for streaming downsampling, but confirm
                // against the reference implementation.
                /* causal */ true,
                /* norm */ norm,
                /* pad_mode */ cfg.pad_mode,
                vb.pp(layer_idx + 1),
            )?;
            layer_idx += 2;
            let layer = EncoderLayer { downsample, residuals };
            layers.push(layer);
            mult *= 2
        }

        let final_norm =
            if cfg.disable_norm_outer_blocks >= n_blocks { None } else { Some(cfg.norm) };
        let final_conv1d = StreamableConv1d::new(
            mult * cfg.n_filters,
            cfg.dimension,
            cfg.last_kernel_size,
            /* stride */ 1,
            /* dilation */ 1,
            /* groups */ 1,
            /* bias */ true,
            /* causal */ cfg.causal,
            /* norm */ final_norm,
            /* pad_mode */ cfg.pad_mode,
            vb.pp(layer_idx + 1),
        )?;
        Ok(Self {
            init_conv1d,
            activation: cfg.activation,
            layers,
            final_conv1d,
            span: tracing::span!(tracing::Level::TRACE, "sea-encoder"),
        })
    }

    /// Resets the streaming state of every sub-module for one batch entry.
    pub fn reset_batch_idx(&mut self, batch_idx: usize, batch_size: usize) -> Result<()> {
        self.init_conv1d.reset_batch_idx(batch_idx, batch_size)?;
        self.final_conv1d.reset_batch_idx(batch_idx, batch_size)?;
        for layer in self.layers.iter_mut() {
            layer.downsample.reset_batch_idx(batch_idx, batch_size)?;
            for l in layer.residuals.iter_mut() {
                l.reset_batch_idx(batch_idx, batch_size)?;
            }
        }
        Ok(())
    }
}
267
268impl Module for SeaNetEncoder {
269    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
270        let _enter = self.span.enter();
271        let mut xs = xs.apply(&self.init_conv1d)?;
272        for layer in self.layers.iter() {
273            for residual in layer.residuals.iter() {
274                xs = xs.apply(residual)?
275            }
276            xs = xs.apply(&self.activation)?.apply(&layer.downsample)?;
277        }
278        xs.apply(&self.activation)?.apply(&self.final_conv1d)
279    }
280}
281
282impl StreamingModule for SeaNetEncoder {
283    fn reset_state(&mut self) {
284        self.init_conv1d.reset_state();
285        self.layers.iter_mut().for_each(|v| {
286            v.residuals.iter_mut().for_each(|v| v.reset_state());
287            v.downsample.reset_state()
288        });
289        self.final_conv1d.reset_state();
290    }
291
292    fn step(&mut self, xs: &StreamTensor, m: &StreamMask) -> Result<StreamTensor> {
293        let _enter = self.span.enter();
294        let mut xs = self.init_conv1d.step(xs, m)?;
295        for layer in self.layers.iter_mut() {
296            for residual in layer.residuals.iter_mut() {
297                xs = residual.step(&xs, m)?;
298            }
299            xs = layer.downsample.step(&xs.apply(&self.activation)?, m)?;
300        }
301        self.final_conv1d.step(&xs.apply(&self.activation)?, m)
302    }
303}
304
/// One decoder stage: a transposed upsampling conv followed by residual blocks.
#[derive(Debug, Clone)]
struct DecoderLayer {
    // Transposed convolution (stride = stage ratio) that increases the
    // temporal resolution and halves the channel count.
    upsample: StreamableConvTranspose1d,
    // Residual blocks applied after upsampling.
    residuals: Vec<SeaNetResnetBlock>,
}
310
/// SEANet decoder: initial convolution from the latent space, a stack of
/// upsample + residual stages (one per ratio), and a final projection back
/// to `Config::channels`.
#[derive(Debug, Clone)]
pub struct SeaNetDecoder {
    init_conv1d: StreamableConv1d,
    // Applied before every upsample and before the final convolution.
    activation: candle_nn::Activation,
    layers: Vec<DecoderLayer>,
    final_conv1d: StreamableConv1d,
    // Optional activation applied to the decoder output.
    final_activation: Option<candle_nn::Activation>,
    span: tracing::Span,
}
320
impl SeaNetDecoder {
    /// Builds the decoder from `cfg`, loading weights from `vb.pp("model")`.
    ///
    /// Mirrors the encoder's weight layout: `layer_idx` indices under `model`
    /// are not contiguous — each upsample conv is read at `layer_idx + 1` and
    /// the counter then advances by 2, leaving gaps (presumably activation
    /// slots in the reference sequential layout — confirm against the
    /// checkpoint).
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        if cfg.lstm > 0 {
            candle::bail!("seanet lstm is not supported")
        }
        // Outer blocks: init conv + one per ratio + final conv.
        let n_blocks = 2 + cfg.ratios.len();
        // Channel multiplier; starts at 2^len(ratios) (the encoder's final
        // multiplier) and is halved after each upsampling stage.
        let mut mult = 1 << cfg.ratios.len();
        let init_norm =
            if cfg.disable_norm_outer_blocks == n_blocks { None } else { Some(cfg.norm) };
        let mut layer_idx = 0;
        let vb = vb.pp("model");
        let init_conv1d = StreamableConv1d::new(
            cfg.dimension,
            mult * cfg.n_filters,
            cfg.kernel_size,
            /* stride */ 1,
            /* dilation */ 1,
            /* groups */ 1,
            /* bias */ true,
            /* causal */ cfg.causal,
            /* norm */ init_norm,
            /* pad_mode */ cfg.pad_mode,
            vb.pp(layer_idx),
        )?;
        layer_idx += 1;
        let mut layers = Vec::with_capacity(cfg.ratios.len());
        for (i, &ratio) in cfg.ratios.iter().enumerate() {
            // Norm is disabled for the last `disable_norm_outer_blocks` stages,
            // the mirror image of the encoder's condition.
            let norm = if cfg.disable_norm_outer_blocks + i + 1 >= n_blocks {
                None
            } else {
                Some(cfg.norm)
            };
            let upsample = StreamableConvTranspose1d::new(
                mult * cfg.n_filters,
                mult * cfg.n_filters / 2,
                /* k_size */ ratio * 2,
                /* stride */ ratio,
                /* groups */ 1,
                /* bias */ true,
                // NOTE(review): hardcoded `true` rather than `cfg.causal` —
                // looks intentional for streaming upsampling, but confirm
                // against the reference implementation.
                /* causal */ true,
                /* norm */ norm,
                vb.pp(layer_idx + 1),
            )?;
            layer_idx += 2;

            let mut residuals = Vec::with_capacity(cfg.n_residual_layers);
            for j in 0..cfg.n_residual_layers {
                let resnet_block = SeaNetResnetBlock::new(
                    mult * cfg.n_filters / 2,
                    // Dilation grows exponentially with the block index `j`;
                    // the trailing (1, 1) entry is the non-dilated 1x1 conv.
                    &[(cfg.residual_kernel_size, cfg.dilation_base.pow(j as u32)), (1, 1)],
                    cfg.activation,
                    norm,
                    cfg.causal,
                    cfg.pad_mode,
                    cfg.compress,
                    cfg.true_skip,
                    vb.pp(layer_idx),
                )?;
                residuals.push(resnet_block);
                layer_idx += 1;
            }
            let layer = DecoderLayer { upsample, residuals };
            layers.push(layer);
            mult /= 2
        }
        let final_norm = if cfg.disable_norm_outer_blocks >= 1 { None } else { Some(cfg.norm) };
        let final_conv1d = StreamableConv1d::new(
            cfg.n_filters,
            cfg.channels,
            cfg.last_kernel_size,
            /* stride */ 1,
            /* dilation */ 1,
            /* groups */ 1,
            /* bias */ true,
            /* causal */ cfg.causal,
            /* norm */ final_norm,
            /* pad_mode */ cfg.pad_mode,
            vb.pp(layer_idx + 1),
        )?;
        Ok(Self {
            init_conv1d,
            activation: cfg.activation,
            layers,
            final_conv1d,
            final_activation: cfg.final_activation,
            span: tracing::span!(tracing::Level::TRACE, "sea-decoder"),
        })
    }

    /// Resets the streaming state of every sub-module for one batch entry.
    pub fn reset_batch_idx(&mut self, batch_idx: usize, batch_size: usize) -> Result<()> {
        self.init_conv1d.reset_batch_idx(batch_idx, batch_size)?;
        self.final_conv1d.reset_batch_idx(batch_idx, batch_size)?;
        for layer in self.layers.iter_mut() {
            layer.upsample.reset_batch_idx(batch_idx, batch_size)?;
            for l in layer.residuals.iter_mut() {
                l.reset_batch_idx(batch_idx, batch_size)?;
            }
        }
        Ok(())
    }
}
422
423impl Module for SeaNetDecoder {
424    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
425        let _enter = self.span.enter();
426        let mut xs = xs.apply(&self.init_conv1d)?;
427        for layer in self.layers.iter() {
428            xs = xs.apply(&self.activation)?.apply(&layer.upsample)?;
429            for residual in layer.residuals.iter() {
430                xs = xs.apply(residual)?
431            }
432        }
433        let xs = xs.apply(&self.activation)?.apply(&self.final_conv1d)?;
434        let xs = match self.final_activation.as_ref() {
435            None => xs,
436            Some(act) => xs.apply(act)?,
437        };
438        Ok(xs)
439    }
440}
441
442impl StreamingModule for SeaNetDecoder {
443    fn reset_state(&mut self) {
444        self.init_conv1d.reset_state();
445        self.layers.iter_mut().for_each(|v| {
446            v.residuals.iter_mut().for_each(|v| v.reset_state());
447            v.upsample.reset_state()
448        });
449        self.final_conv1d.reset_state();
450    }
451
452    fn step(&mut self, xs: &StreamTensor, m: &StreamMask) -> Result<StreamTensor> {
453        let _enter = self.span.enter();
454        let mut xs = self.init_conv1d.step(xs, m)?;
455        for layer in self.layers.iter_mut() {
456            xs = layer.upsample.step(&xs.apply(&self.activation)?, m)?;
457            for residual in layer.residuals.iter_mut() {
458                xs = residual.step(&xs, m)?;
459            }
460        }
461        let xs = self.final_conv1d.step(&xs.apply(&self.activation)?, m)?;
462        let xs = match self.final_activation.as_ref() {
463            None => xs,
464            Some(act) => xs.apply(act)?,
465        };
466        Ok(xs)
467    }
468}