1use crate::streaming::{StreamMask, StreamTensor, StreamingModule};
6use crate::{conv, quantization, seanet, transformer};
7use candle::{DType, Device, Module, Result, Tensor};
8use candle_nn::VarBuilder;
9
/// Names the resampling strategy stored in [`Config::resample_method`].
///
/// NOTE(review): the model constructor in this file always builds the
/// convolutional down/upsampling layers and never reads this value — confirm
/// where (or whether) `Interpolate` is honoured.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ResampleMethod {
    /// Convolution-based resampling (the value used by the `v0_1` preset).
    Conv,
    /// Interpolation-based resampling.
    Interpolate,
}
15
16#[derive(Debug, Clone)]
17pub struct Config {
18 pub channels: usize,
19 pub sample_rate: f64,
20 pub frame_rate: f64,
21 pub renormalize: bool,
22 pub resample_method: ResampleMethod,
23 pub seanet: seanet::Config,
24 pub transformer: transformer::Config,
25 pub quantizer_n_q: usize,
26 pub quantizer_bins: usize,
27 pub quantizer_dim: usize,
28}
29
30impl Config {
31 pub fn v0_1(num_codebooks: Option<usize>) -> Self {
33 let seanet_cfg = seanet::Config {
34 dimension: 512,
35 channels: 1,
36 causal: true,
37 n_filters: 64,
38 n_residual_layers: 1,
39 activation: candle_nn::Activation::Elu(1.),
40 compress: 2,
41 dilation_base: 2,
42 disable_norm_outer_blocks: 0,
43 final_activation: None,
44 kernel_size: 7,
45 residual_kernel_size: 3,
46 last_kernel_size: 3,
47 lstm: 0,
48 norm: conv::Norm::WeightNorm,
49 pad_mode: conv::PadMode::Constant,
50 ratios: vec![8, 6, 5, 4],
51 true_skip: true,
52 };
53 let transformer_cfg = transformer::Config {
54 d_model: seanet_cfg.dimension,
55 num_heads: 8,
56 num_layers: 8,
57 causal: true,
58 norm_first: true,
59 bias_ff: false,
60 bias_attn: false,
61 layer_scale: Some(0.01),
62 context: 250,
63 conv_kernel_size: 5,
64 use_conv_bias: true,
65 use_conv_block: false,
66 cross_attention: None,
67 max_period: 10000,
68 gating: None,
69 norm: crate::NormType::LayerNorm,
70 positional_embedding: transformer::PositionalEmbedding::Rope,
71
72 dim_feedforward: 2048,
73 kv_repeat: 1,
74 conv_layout: true, max_seq_len: 8192, shared_cross_attn: false,
77 };
78 Config {
79 channels: 1,
80 sample_rate: 24_000.,
81 frame_rate: 12.5,
82 renormalize: true,
83 resample_method: ResampleMethod::Conv,
84 seanet: seanet_cfg,
85 transformer: transformer_cfg,
86 quantizer_n_q: num_codebooks.unwrap_or(16),
87 quantizer_bins: 2048,
88 quantizer_dim: 256,
89 }
90 }
91}
92
93#[derive(Debug, Clone)]
94pub struct Mimi {
95 encoder: seanet::SeaNetEncoder,
96 decoder: seanet::SeaNetDecoder,
97 encoder_transformer: transformer::Transformer,
98 decoder_transformer: transformer::Transformer,
99 downsample: conv::ConvDownsample1d,
100 upsample: conv::ConvTrUpsample1d,
101 quantizer: quantization::SplitResidualVectorQuantizer,
102 config: Config,
103}
104
105impl Mimi {
106 pub fn new(cfg: Config, vb: VarBuilder) -> Result<Self> {
107 Self::new_(None, cfg, vb)
108 }
109
110 pub fn batched(batch_size: usize, cfg: Config, vb: VarBuilder) -> Result<Self> {
111 Self::new_(Some(batch_size), cfg, vb)
112 }
113
114 fn new_(batch_size: Option<usize>, cfg: Config, vb: VarBuilder) -> Result<Self> {
115 let dim = cfg.seanet.dimension;
116 let encoder = seanet::SeaNetEncoder::new(&cfg.seanet, vb.pp("encoder"))?;
117 let decoder = seanet::SeaNetDecoder::new(&cfg.seanet, vb.pp("decoder"))?;
118 let encoder_transformer = transformer::Transformer::new(
119 batch_size,
120 dim,
121 &cfg.transformer,
122 vb.pp("encoder_transformer"),
123 )?;
124 let decoder_transformer = transformer::Transformer::new(
125 batch_size,
126 dim,
127 &cfg.transformer,
128 vb.pp("decoder_transformer"),
129 )?;
130 let quantizer = quantization::SplitResidualVectorQuantizer::new(
131 cfg.quantizer_dim,
132 Some(dim),
133 Some(dim),
134 cfg.quantizer_n_q,
135 cfg.quantizer_bins,
136 vb.pp("quantizer"),
137 )?;
138 let encoder_frame_rate =
139 cfg.sample_rate / cfg.seanet.ratios.iter().product::<usize>() as f64;
140
141 let downsample_stride = (encoder_frame_rate / cfg.frame_rate) as usize;
142 let downsample = conv::ConvDownsample1d::new(
144 downsample_stride,
145 dim,
146 true,
147 true,
148 vb.pp("downsample"),
149 )?;
150 let upsample = conv::ConvTrUpsample1d::new(
151 downsample_stride,
152 dim,
153 true,
154 true,
155 vb.pp("upsample"),
156 )?;
157
158 Ok(Self {
159 encoder,
160 decoder,
161 encoder_transformer,
162 decoder_transformer,
163 quantizer,
164 downsample,
165 upsample,
166 config: cfg,
167 })
168 }
169
170 pub fn config(&self) -> &Config {
171 &self.config
172 }
173
174 pub fn encode_pre_quantize(&mut self, xs: &Tensor) -> Result<Tensor> {
175 let xs = self.encoder.forward(xs)?;
176 self.encoder_transformer.reset_state();
177 let xs = self.encoder_transformer.forward(&xs)?;
178 let xs = &xs[0];
179 xs.apply(&self.downsample)
180 }
181
182 pub fn encode(&mut self, xs: &Tensor) -> Result<Tensor> {
183 let xs = self.encoder.forward(xs)?;
184 self.encoder_transformer.reset_state();
185 let xs = self.encoder_transformer.forward(&xs)?;
186 let xs = &xs[0];
187 let xs = xs.apply(&self.downsample)?;
188 let codes = self.quantizer.encode(&xs)?;
189 Ok(codes)
190 }
191
192 pub fn encode_step(&mut self, xs: &StreamTensor, m: &StreamMask) -> Result<StreamTensor> {
193 let xs = self.encoder.step(xs, m)?;
194 let xs = self.encoder_transformer.step(&xs, m)?;
195 let xs = self.downsample.step(&xs, m)?;
196 match xs.as_option() {
197 None => Ok(().into()),
198 Some(xs) => {
199 let codes = self.quantizer.encode(xs)?;
200 Ok(codes.into())
201 }
202 }
203 }
204
205 pub fn decode(&mut self, codes: &Tensor) -> Result<Tensor> {
206 let emb = self.quantizer.decode(codes)?;
207 let emb = emb.apply(&self.upsample)?;
208 self.decoder_transformer.reset_state();
209 let outs = self.decoder_transformer.forward(&emb)?;
210 let out = &outs[0];
211 self.decoder.forward(out)
212 }
213
214 pub fn decode_step(&mut self, codes: &StreamTensor, m: &StreamMask) -> Result<StreamTensor> {
215 let emb = match codes.as_option() {
216 Some(codes) => StreamTensor::from_tensor(self.quantizer.decode(codes)?),
217 None => StreamTensor::empty(),
218 };
219 let emb = self.upsample.step(&emb, m)?;
220 let out = self.decoder_transformer.step(&emb, m)?;
221 self.decoder.step(&out, m)
222 }
223
224 pub fn reset_state(&mut self) {
225 self.encoder.reset_state();
226 self.encoder_transformer.reset_state();
227 self.decoder.reset_state();
228 self.decoder_transformer.reset_state();
229 self.upsample.reset_state();
230 self.downsample.reset_state();
231 }
232
233 pub fn reset_batch_idx(&mut self, batch_idx: usize, batch_size: usize) -> Result<()> {
234 self.encoder_transformer.reset_batch_idx(batch_idx, batch_size)?;
235 self.encoder_transformer.reset_batch_idx(batch_idx, batch_size)?;
236 self.encoder.reset_batch_idx(batch_idx, batch_size)?;
237 self.decoder.reset_batch_idx(batch_idx, batch_size)?;
238 self.upsample.reset_batch_idx(batch_idx, batch_size)?;
239 self.downsample.reset_batch_idx(batch_idx, batch_size)?;
240 Ok(())
241 }
242}
243
244pub fn load(model_file: &str, num_codebooks: Option<usize>, dev: &Device) -> Result<Mimi> {
245 let vb =
246 unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, dev)? };
247 let cfg = Config::v0_1(num_codebooks);
248 let mimi = Mimi::new(cfg, vb)?;
249 Ok(mimi)
250}
251
252pub fn load_b(
253 batch_size: Option<usize>,
254 model_file: &str,
255 num_codebooks: Option<usize>,
256 dev: &Device,
257) -> Result<Mimi> {
258 let vb =
259 unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, dev)? };
260 let cfg = Config::v0_1(num_codebooks);
261 let mimi = Mimi::new_(batch_size, cfg, vb)?;
262 Ok(mimi)
263}