// moshi_db/mimi.rs
// Copyright (c) Kyutai, all rights reserved.
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
4
5use crate::streaming::{StreamMask, StreamTensor, StreamingModule};
6use crate::{conv, quantization, seanet, transformer};
7use candle::{DType, Device, Module, Result, Tensor};
8use candle_nn::VarBuilder;
9
/// Strategy used to bridge the encoder frame rate and the model frame rate.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum ResampleMethod {
    /// Resample with learnt strided (transposed) convolutions.
    Conv,
    /// Resample by interpolation.
    Interpolate,
}
15
16#[derive(Debug, Clone)]
17pub struct Config {
18    pub channels: usize,
19    pub sample_rate: f64,
20    pub frame_rate: f64,
21    pub renormalize: bool,
22    pub resample_method: ResampleMethod,
23    pub seanet: seanet::Config,
24    pub transformer: transformer::Config,
25    pub quantizer_n_q: usize,
26    pub quantizer_bins: usize,
27    pub quantizer_dim: usize,
28}
29
30impl Config {
31    // /lustre/scwpod02/client/kyutai/alex/mimi_exp/xps/b7d2bd5a/.hydra/config.yaml
32    pub fn v0_1(num_codebooks: Option<usize>) -> Self {
33        let seanet_cfg = seanet::Config {
34            dimension: 512,
35            channels: 1,
36            causal: true,
37            n_filters: 64,
38            n_residual_layers: 1,
39            activation: candle_nn::Activation::Elu(1.),
40            compress: 2,
41            dilation_base: 2,
42            disable_norm_outer_blocks: 0,
43            final_activation: None,
44            kernel_size: 7,
45            residual_kernel_size: 3,
46            last_kernel_size: 3,
47            lstm: 0,
48            norm: conv::Norm::WeightNorm,
49            pad_mode: conv::PadMode::Constant,
50            ratios: vec![8, 6, 5, 4],
51            true_skip: true,
52        };
53        let transformer_cfg = transformer::Config {
54            d_model: seanet_cfg.dimension,
55            num_heads: 8,
56            num_layers: 8,
57            causal: true,
58            norm_first: true,
59            bias_ff: false,
60            bias_attn: false,
61            layer_scale: Some(0.01),
62            context: 250,
63            conv_kernel_size: 5,
64            use_conv_bias: true,
65            use_conv_block: false,
66            cross_attention: None,
67            max_period: 10000,
68            gating: None,
69            norm: crate::NormType::LayerNorm,
70            positional_embedding: transformer::PositionalEmbedding::Rope,
71
72            dim_feedforward: 2048,
73            kv_repeat: 1,
74            conv_layout: true, // see builders.py
75            max_seq_len: 8192, // the transformer works at 25hz so this is ~5 mins.
76            shared_cross_attn: false,
77        };
78        Config {
79            channels: 1,
80            sample_rate: 24_000.,
81            frame_rate: 12.5,
82            renormalize: true,
83            resample_method: ResampleMethod::Conv,
84            seanet: seanet_cfg,
85            transformer: transformer_cfg,
86            quantizer_n_q: num_codebooks.unwrap_or(16),
87            quantizer_bins: 2048,
88            quantizer_dim: 256,
89        }
90    }
91}
92
93#[derive(Debug, Clone)]
94pub struct Mimi {
95    encoder: seanet::SeaNetEncoder,
96    decoder: seanet::SeaNetDecoder,
97    encoder_transformer: transformer::Transformer,
98    decoder_transformer: transformer::Transformer,
99    downsample: conv::ConvDownsample1d,
100    upsample: conv::ConvTrUpsample1d,
101    quantizer: quantization::SplitResidualVectorQuantizer,
102    config: Config,
103}
104
105impl Mimi {
106    pub fn new(cfg: Config, vb: VarBuilder) -> Result<Self> {
107        Self::new_(None, cfg, vb)
108    }
109
110    pub fn batched(batch_size: usize, cfg: Config, vb: VarBuilder) -> Result<Self> {
111        Self::new_(Some(batch_size), cfg, vb)
112    }
113
114    fn new_(batch_size: Option<usize>, cfg: Config, vb: VarBuilder) -> Result<Self> {
115        let dim = cfg.seanet.dimension;
116        let encoder = seanet::SeaNetEncoder::new(&cfg.seanet, vb.pp("encoder"))?;
117        let decoder = seanet::SeaNetDecoder::new(&cfg.seanet, vb.pp("decoder"))?;
118        let encoder_transformer = transformer::Transformer::new(
119            batch_size,
120            dim,
121            &cfg.transformer,
122            vb.pp("encoder_transformer"),
123        )?;
124        let decoder_transformer = transformer::Transformer::new(
125            batch_size,
126            dim,
127            &cfg.transformer,
128            vb.pp("decoder_transformer"),
129        )?;
130        let quantizer = quantization::SplitResidualVectorQuantizer::new(
131            /* dim */ cfg.quantizer_dim,
132            /* input_dim */ Some(dim),
133            /* output_dim */ Some(dim),
134            /* n_q */ cfg.quantizer_n_q,
135            /* bins */ cfg.quantizer_bins,
136            vb.pp("quantizer"),
137        )?;
138        let encoder_frame_rate =
139            cfg.sample_rate / cfg.seanet.ratios.iter().product::<usize>() as f64;
140
141        let downsample_stride = (encoder_frame_rate / cfg.frame_rate) as usize;
142        // `upsample` and `downsample` only apply if frame_rate is different from encoder_frame_rate.
143        let downsample = conv::ConvDownsample1d::new(
144            /* stride */ downsample_stride,
145            /* dim */ dim,
146            /* causal */ true,
147            /* learnt */ true,
148            vb.pp("downsample"),
149        )?;
150        let upsample = conv::ConvTrUpsample1d::new(
151            /* stride */ downsample_stride,
152            /* dim */ dim,
153            /* causal */ true,
154            /* learnt */ true,
155            vb.pp("upsample"),
156        )?;
157
158        Ok(Self {
159            encoder,
160            decoder,
161            encoder_transformer,
162            decoder_transformer,
163            quantizer,
164            downsample,
165            upsample,
166            config: cfg,
167        })
168    }
169
170    pub fn config(&self) -> &Config {
171        &self.config
172    }
173
174    pub fn encode_pre_quantize(&mut self, xs: &Tensor) -> Result<Tensor> {
175        let xs = self.encoder.forward(xs)?;
176        self.encoder_transformer.reset_state();
177        let xs = self.encoder_transformer.forward(&xs)?;
178        let xs = &xs[0];
179        xs.apply(&self.downsample)
180    }
181
182    pub fn encode(&mut self, xs: &Tensor) -> Result<Tensor> {
183        let xs = self.encoder.forward(xs)?;
184        self.encoder_transformer.reset_state();
185        let xs = self.encoder_transformer.forward(&xs)?;
186        let xs = &xs[0];
187        let xs = xs.apply(&self.downsample)?;
188        let codes = self.quantizer.encode(&xs)?;
189        Ok(codes)
190    }
191
192    pub fn encode_step(&mut self, xs: &StreamTensor, m: &StreamMask) -> Result<StreamTensor> {
193        let xs = self.encoder.step(xs, m)?;
194        let xs = self.encoder_transformer.step(&xs, m)?;
195        let xs = self.downsample.step(&xs, m)?;
196        match xs.as_option() {
197            None => Ok(().into()),
198            Some(xs) => {
199                let codes = self.quantizer.encode(xs)?;
200                Ok(codes.into())
201            }
202        }
203    }
204
205    pub fn decode(&mut self, codes: &Tensor) -> Result<Tensor> {
206        let emb = self.quantizer.decode(codes)?;
207        let emb = emb.apply(&self.upsample)?;
208        self.decoder_transformer.reset_state();
209        let outs = self.decoder_transformer.forward(&emb)?;
210        let out = &outs[0];
211        self.decoder.forward(out)
212    }
213
214    pub fn decode_step(&mut self, codes: &StreamTensor, m: &StreamMask) -> Result<StreamTensor> {
215        let emb = match codes.as_option() {
216            Some(codes) => StreamTensor::from_tensor(self.quantizer.decode(codes)?),
217            None => StreamTensor::empty(),
218        };
219        let emb = self.upsample.step(&emb, m)?;
220        let out = self.decoder_transformer.step(&emb, m)?;
221        self.decoder.step(&out, m)
222    }
223
224    pub fn reset_state(&mut self) {
225        self.encoder.reset_state();
226        self.encoder_transformer.reset_state();
227        self.decoder.reset_state();
228        self.decoder_transformer.reset_state();
229        self.upsample.reset_state();
230        self.downsample.reset_state();
231    }
232
233    pub fn reset_batch_idx(&mut self, batch_idx: usize, batch_size: usize) -> Result<()> {
234        self.encoder_transformer.reset_batch_idx(batch_idx, batch_size)?;
235        self.encoder_transformer.reset_batch_idx(batch_idx, batch_size)?;
236        self.encoder.reset_batch_idx(batch_idx, batch_size)?;
237        self.decoder.reset_batch_idx(batch_idx, batch_size)?;
238        self.upsample.reset_batch_idx(batch_idx, batch_size)?;
239        self.downsample.reset_batch_idx(batch_idx, batch_size)?;
240        Ok(())
241    }
242}
243
244pub fn load(model_file: &str, num_codebooks: Option<usize>, dev: &Device) -> Result<Mimi> {
245    let vb =
246        unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, dev)? };
247    let cfg = Config::v0_1(num_codebooks);
248    let mimi = Mimi::new(cfg, vb)?;
249    Ok(mimi)
250}
251
252pub fn load_b(
253    batch_size: Option<usize>,
254    model_file: &str,
255    num_codebooks: Option<usize>,
256    dev: &Device,
257) -> Result<Mimi> {
258    let vb =
259        unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, dev)? };
260    let cfg = Config::v0_1(num_codebooks);
261    let mimi = Mimi::new_(batch_size, cfg, vb)?;
262    Ok(mimi)
263}