// moshi_db/lm.rs

// Copyright (c) Kyutai, all rights reserved.
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

use crate::nn::{linear, MaybeQuantizedEmbedding, MaybeQuantizedLinear, MaybeQuantizedVarBuilder};
use crate::{
    batched_transformer,
    transformer::{self, CaSrc},
    NormType, StreamMask,
};
use candle::{DType, Device, IndexOp, Module, Result, Tensor};

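// Debug logging is gated on the MIMI_VERBOSE environment variable: any
// non-empty value other than "0" enables it.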
thread_local! {
    pub static VERBOSE: bool = {
        match std::env::var("MIMI_VERBOSE") {
            Ok(s) => {
                !s.is_empty() && s != "0"
            },
            Err(_) => false,
        }
    }
}

#[derive(Debug, Clone, serde::Deserialize)]
pub struct DepFormerConfig {
    pub transformer: transformer::Config,
    pub num_slices: usize,
    pub low_rank_embeddings: Option<usize>,
}

#[derive(Debug, Clone, serde::Deserialize)]
pub struct ExtraHeadsConfig {
    pub num_heads: usize,
    pub dim: usize,
}

#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
    pub transformer: transformer::Config,
    pub depformer: Option<DepFormerConfig>,
    pub text_in_vocab_size: usize,
    pub text_out_vocab_size: usize,
    pub audio_vocab_size: usize,
    pub audio_codebooks: usize,
    pub conditioners: Option<crate::conditioner::Config>,
    pub extra_heads: Option<ExtraHeadsConfig>,
}

impl Config {
    fn depformer_cfg(num_slices: usize) -> DepFormerConfig {
        let depformer_cfg = transformer::Config {
            d_model: 1024,
            num_heads: 16,
            num_layers: 6,
            dim_feedforward: 1024 * 4, // dim * hidden_scale
            causal: true,
            norm_first: true,
            bias_ff: false,
            bias_attn: false,
            layer_scale: None,
            context: num_slices,
            max_period: 10000,
            use_conv_block: false,
            use_conv_bias: true,
            cross_attention: None,
            gating: Some(candle_nn::Activation::Silu),
            norm: NormType::RmsNorm,
            positional_embedding: transformer::PositionalEmbedding::None,
            conv_layout: false,
            conv_kernel_size: 3,
            kv_repeat: 1,
            max_seq_len: 4096,
            shared_cross_attn: false,
        };
        DepFormerConfig { num_slices, transformer: depformer_cfg, low_rank_embeddings: None }
    }

    // /lustre/scwpod02/client/kyutai/alex/mimi_exp/xps/af78657c/outputs/hyperparams.json
    // Update 2024-03-19: Sin embeddings -> None, RmsNorm fix, scale factor 4.125
    // Update 2024-05-02: split text_vocab_size into text_in_vocab_size and
    // text_out_vocab_size embeddings.
    pub fn v0_1() -> Self {
        let lm_cfg = transformer::Config {
            d_model: 4096,
            num_heads: 32,
            num_layers: 32,
            dim_feedforward: 4096 * 4, // dim * hidden_scale
            causal: true,
            norm_first: true,
            bias_ff: false,
            bias_attn: false,
            layer_scale: None,
            context: 3000,
            max_period: 10000,
            use_conv_block: false,
            use_conv_bias: true,
            cross_attention: None,
            gating: Some(candle_nn::Activation::Silu),
            norm: NormType::RmsNorm,
            positional_embedding: transformer::PositionalEmbedding::Rope,
            conv_layout: false,
            conv_kernel_size: 3,
            kv_repeat: 1,
            max_seq_len: 4096,
            shared_cross_attn: false,
        };
        Self {
            transformer: lm_cfg,
            depformer: Some(Self::depformer_cfg(8)),
            audio_vocab_size: 2049,
            text_in_vocab_size: 32001,
            text_out_vocab_size: 32000,
            audio_codebooks: 8,
            conditioners: Default::default(),
            extra_heads: None,
        }
    }

    pub fn v0_1_vision() -> Self {
        let lm_cfg = transformer::Config {
            d_model: 4096,
            num_heads: 32,
            num_layers: 32,
            dim_feedforward: 4096 * 4, // dim * hidden_scale
            causal: true,
            norm_first: true,
            bias_ff: false,
            bias_attn: false,
            layer_scale: None,
            context: 3000,
            max_period: 10000,
            use_conv_block: false,
            use_conv_bias: true,
            cross_attention: Some((
                transformer::CrossAttentionGating::ConditionalGatedSigmoid,
                NormType::RmsNorm,
                None,
            )),
            gating: Some(candle_nn::Activation::Silu),
            norm: NormType::RmsNorm,
            positional_embedding: transformer::PositionalEmbedding::Rope,
            conv_layout: false,
            conv_kernel_size: 3,
            kv_repeat: 1,
            max_seq_len: 4096,
            shared_cross_attn: true,
        };
        Self {
            transformer: lm_cfg,
            depformer: Some(Self::depformer_cfg(8)),
            audio_vocab_size: 2049,
            text_in_vocab_size: 32001,
            text_out_vocab_size: 32000,
            audio_codebooks: 8,
            conditioners: Default::default(),
            extra_heads: None,
        }
    }

    pub fn v0_1_vision_streaming(num_slices: usize) -> Self {
        let mut s = Self::v0_1_vision();
        s.audio_codebooks = 16;
        if let Some(depformer) = s.depformer.as_mut() {
            depformer.num_slices = num_slices;
            depformer.transformer.context = num_slices;
        }
        s
    }

    pub fn v0_1_streaming(num_slices: usize) -> Self {
        let mut s = Self::v0_1();
        s.audio_codebooks = 16;
        if let Some(depformer) = s.depformer.as_mut() {
            depformer.num_slices = num_slices;
            depformer.transformer.context = num_slices;
        }
        s
    }

    pub fn v0_1_asr() -> Self {
        let mut s = Self::v0_1();
        s.audio_codebooks = 8;
        if let Some(depformer) = s.depformer.as_mut() {
            depformer.num_slices = 0;
            depformer.transformer.context = 0;
        }
        s
    }

    // /lustre/scwpod02/client/kyutai/neilz/mimi_exp/xps/6bbe4692/outputs/hyperparams.json
    pub fn tts_v0_1() -> Self {
        let lm_cfg = transformer::Config {
            d_model: 2048,
            num_heads: 32,
            num_layers: 48,
            dim_feedforward: 4096 * 2, // dim * hidden_scale
            causal: true,
            norm_first: true,
            bias_ff: false,
            bias_attn: false,
            layer_scale: None,
            context: 4096,
            max_period: 10000,
            use_conv_block: false,
            use_conv_bias: true,
            cross_attention: Some((
                transformer::CrossAttentionGating::Normal,
                NormType::LayerNorm,
                None,
            )),
            gating: None,
            norm: NormType::LayerNorm,
            positional_embedding: transformer::PositionalEmbedding::Rope,
            conv_layout: false,
            conv_kernel_size: 3,
            kv_repeat: 1,
            max_seq_len: 4096,
            shared_cross_attn: false,
        };
        Self {
            transformer: lm_cfg,
            depformer: Some(Self::depformer_cfg(16)),
            audio_vocab_size: 2050,
            text_in_vocab_size: 32001,
            text_out_vocab_size: 32001,
            audio_codebooks: 16,
            conditioners: Default::default(),
            extra_heads: None,
        }
    }

    // /lustre/scwpod02/client/kyutai-interns/tomlab/mimi_exp/xps/c879d080/.hydra/config.yaml
    // /lustre/scwpod02/client/kyutai-interns/tomlab/mimi_exp/xps/41e5e07d/.hydra/config.yaml
    pub fn s2s_v0_1() -> Self {
        let lm_cfg = transformer::Config {
            d_model: 2048,
            num_heads: 16,
            num_layers: 16,
            dim_feedforward: 4096 * 2, // dim * hidden_scale
            causal: true,
            norm_first: true,
            bias_ff: false,
            bias_attn: false,
            layer_scale: None,
            context: 3000,
            max_period: 10000,
            use_conv_block: false,
            use_conv_bias: true,
            cross_attention: None,
            gating: Some(candle_nn::Activation::Silu),
            norm: NormType::RmsNorm,
            positional_embedding: transformer::PositionalEmbedding::Rope,
            conv_layout: false,
            conv_kernel_size: 3,
            kv_repeat: 1,
            max_seq_len: 4096,
            shared_cross_attn: false,
        };
        Self {
            transformer: lm_cfg,
            depformer: Some(Self::depformer_cfg(16)),
            audio_vocab_size: 2049,
            text_in_vocab_size: 48001,
            text_out_vocab_size: 48000,
            audio_codebooks: 16,
            conditioners: Default::default(),
            extra_heads: None,
        }
    }

    pub fn s2s_v0_1_streaming(num_slices: usize) -> Self {
        let mut s = Self::s2s_v0_1();
        s.audio_codebooks = 16;
        if let Some(depformer) = s.depformer.as_mut() {
            depformer.num_slices = num_slices;
            depformer.transformer.context = num_slices;
        }
        s
    }

    // /lustre/scwpod02/client/kyutai/neilz/mimi_exp/xps/33e476c7/.hydra/config.yaml
    pub fn asr_v0_1_1b() -> Self {
        let lm_cfg = transformer::Config {
            d_model: 2048,
            num_heads: 16,
            num_layers: 16,
            dim_feedforward: 2048 * 4,
            causal: true,
            norm_first: true,
            bias_ff: false,
            bias_attn: false,
            layer_scale: None,
            context: 750,
            max_period: 100_000,
            use_conv_block: false,
            use_conv_bias: true,
            cross_attention: None,
            gating: Some(candle_nn::Activation::Silu),
            norm: NormType::RmsNorm,
            positional_embedding: transformer::PositionalEmbedding::Rope,
            conv_layout: false,
            conv_kernel_size: 3,
            kv_repeat: 1,
            max_seq_len: 4096,
            shared_cross_attn: false,
        };
        Self {
            transformer: lm_cfg,
            depformer: None,
            audio_vocab_size: 2049,
            text_in_vocab_size: 48001,
            text_out_vocab_size: 48000,
            audio_codebooks: 8,
            conditioners: Default::default(),
            extra_heads: None,
        }
    }

    pub fn asr_300m_202501() -> Self {
        let lm_cfg = transformer::Config {
            d_model: 1024,
            num_heads: 8,
            num_layers: 16,
            dim_feedforward: 1024 * 4,
            causal: true,
            norm_first: true,
            bias_ff: false,
            bias_attn: false,
            layer_scale: None,
            context: 750,
            max_period: 100_000,
            use_conv_block: false,
            use_conv_bias: true,
            cross_attention: None,
            gating: Some(candle_nn::Activation::Silu),
            norm: NormType::RmsNorm,
            positional_embedding: transformer::PositionalEmbedding::Rope,
            conv_layout: false,
            conv_kernel_size: 3,
            kv_repeat: 1,
            max_seq_len: 4096,
            shared_cross_attn: false,
        };
        Self {
            transformer: lm_cfg,
            depformer: None,
            audio_vocab_size: 2049,
            text_in_vocab_size: 48001,
            text_out_vocab_size: 48000,
            audio_codebooks: 32,
            conditioners: Default::default(),
            extra_heads: None,
        }
    }

    // /lustre/scwpod02/client/kyutai/alex/mimi_exp/xps/d50593ae/.hydra/config.yaml
    pub fn tts_202501() -> Self {
        let lm_cfg = transformer::Config {
            d_model: 2048,
            num_heads: 32,
            num_layers: 48,
            dim_feedforward: 2048 * 4, // dim * hidden_scale
            causal: true,
            norm_first: true,
            bias_ff: false,
            bias_attn: false,
            layer_scale: None,
            context: 500,
            max_period: 10000,
            use_conv_block: false,
            use_conv_bias: true,
            cross_attention: Some((
                transformer::CrossAttentionGating::Normal,
                NormType::LayerNorm,
                None,
            )),
            gating: Some(candle_nn::Activation::Silu),
            norm: NormType::RmsNorm,
            positional_embedding: transformer::PositionalEmbedding::Rope,
            conv_layout: false,
            conv_kernel_size: 3,
            kv_repeat: 1,
            max_seq_len: 4096,
            shared_cross_attn: false,
        };
        Self {
            transformer: lm_cfg,
            depformer: Some(Self::depformer_cfg(32)),
            audio_vocab_size: 2049,
            text_in_vocab_size: 8001,
            text_out_vocab_size: 8000,
            audio_codebooks: 32,
            conditioners: Default::default(),
            extra_heads: None,
        }
    }

    // /lustre/scwpod02/client/kyutai-interns/tomlab/mimi_exp/xps/1d426dfd/.hydra/config.yaml
    pub fn s2s_2b_16rvq_202501() -> Self {
        let lm_cfg = transformer::Config {
            d_model: 2560,
            num_heads: 20,
            num_layers: 24,
            dim_feedforward: 2560 * 4, // dim * hidden_scale
            causal: true,
            norm_first: true,
            bias_ff: false,
            bias_attn: false,
            layer_scale: None,
            context: 3000,
            max_period: 100000,
            use_conv_block: false,
            use_conv_bias: true,
            cross_attention: None,
            gating: Some(candle_nn::Activation::Silu),
            norm: NormType::RmsNorm,
            positional_embedding: transformer::PositionalEmbedding::Rope,
            conv_layout: false,
            conv_kernel_size: 3,
            kv_repeat: 1,
            max_seq_len: 4096,
            shared_cross_attn: false,
        };
        Self {
            transformer: lm_cfg,
            depformer: Some(Self::depformer_cfg(16)),
            audio_vocab_size: 2049,
            text_in_vocab_size: 48001,
            text_out_vocab_size: 48000,
            audio_codebooks: 32,
            conditioners: Default::default(),
            extra_heads: None,
        }
    }
}

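/// Token embeddings, optionally factored through a low-rank bottleneck: tokens
/// are embedded into `low_rank_dim` dimensions and then projected up to `dim`
/// by a bias-free linear layer.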
#[derive(Debug, Clone)]
struct LowRankEmbeddings {
    embeddings: MaybeQuantizedEmbedding,
    low_rank: Option<MaybeQuantizedLinear>,
}

impl LowRankEmbeddings {
    fn new(
        in_vocab_size: usize,
        dim: usize,
        low_rank_dim: Option<usize>,
        vb: MaybeQuantizedVarBuilder,
    ) -> Result<Self> {
        let (low_rank, embeddings) = match low_rank_dim {
            None => {
                let embeddings = MaybeQuantizedEmbedding::new(in_vocab_size, dim, vb)?;
                (None, embeddings)
            }
            Some(low_rank_dim) => {
                let low_rank = linear(low_rank_dim, dim, false, vb.pp("low_rank"))?;
                let embeddings = MaybeQuantizedEmbedding::new(in_vocab_size, low_rank_dim, vb)?;
                (Some(low_rank), embeddings)
            }
        };
        Ok(Self { embeddings, low_rank })
    }
}

impl Module for LowRankEmbeddings {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let embs = xs.apply(&self.embeddings)?;
        match self.low_rank.as_ref() {
            None => Ok(embs),
            Some(lr) => embs.apply(lr),
        }
    }
}

#[derive(Debug, Clone)]
struct DepFormerSlice {
    // There is no need for a streaming+batching mode here as the depformer does not have
    // "persistent" caches.
    transformer: transformer::StreamingTransformer,
    // Note that the embedding table for the first slice does not have the same vocabulary
    // size as the tables for the other slices, as it takes a text token as input rather
    // than an audio token.
    emb: LowRankEmbeddings,
    linear_in: MaybeQuantizedLinear,  // depformer_in.{idx}
    linear_out: MaybeQuantizedLinear, // linears.{idx}
}

impl DepFormerSlice {
    fn new(
        in_vocab_size: usize,
        out_vocab_size: usize,
        main_transformer_dim: usize,
        cfg: &DepFormerConfig,
        vb: MaybeQuantizedVarBuilder,
    ) -> Result<Self> {
        let dim = cfg.transformer.d_model;
        let transformer =
            transformer::StreamingTransformer::new(&cfg.transformer, vb.pp("transformer"))?;
        let emb =
            LowRankEmbeddings::new(in_vocab_size, dim, cfg.low_rank_embeddings, vb.pp("emb"))?;
        let linear_in = linear(main_transformer_dim, dim, false, vb.pp("linear_in"))?;
        let linear_out = linear(dim, out_vocab_size, false, vb.pp("linear_out"))?;
        Ok(Self { transformer, emb, linear_in, linear_out })
    }
}

#[derive(Debug, Clone)]
pub struct DepFormer {
    slices: Vec<DepFormerSlice>,
}

impl DepFormer {
    pub fn new(
        text_vocab_size: usize,
        audio_vocab_size: usize,
        main_transformer_dim: usize,
        cfg: &DepFormerConfig,
        vb: MaybeQuantizedVarBuilder,
    ) -> Result<Self> {
        let mut slices = Vec::with_capacity(cfg.num_slices);
        for slice_idx in 0..cfg.num_slices {
            let in_vs = if slice_idx == 0 { text_vocab_size } else { audio_vocab_size };
            let slice = DepFormerSlice::new(
                in_vs,
                audio_vocab_size - 1, // The depformer cannot emit an audio padding token.
                main_transformer_dim,
                cfg,
                vb.pp(slice_idx),
            )?;
            slices.push(slice)
        }
        Ok(Self { slices })
    }

    /// Run a transformer sampling step, getting a token id per codebook.
    /// - `xs` is the previous layer hidden state.
    pub fn sample(
        &mut self,
        xs: &Tensor,
        text_token: Option<u32>,
        forced_audio_tokens: &[Option<u32>],
        lp: &mut candle_transformers::generation::LogitsProcessor,
    ) -> Result<Vec<u32>> {
        use crate::streaming::StreamingModule;
        let dev = xs.device();
        let mut tokens = Vec::with_capacity(self.slices.len());
        let mut last_token = text_token;
        for slice_idx in 0..self.slices.len() {
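            // Slice 0 starts a fresh depformer sequence; each subsequent slice
            // continues from the previous slice's state, so the depformer
            // attends over all tokens sampled so far within this step.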
            if slice_idx == 0 {
                self.slices[slice_idx].transformer.reset_state();
            } else {
                let (lhs, rhs) = self.slices.split_at_mut(slice_idx);
                rhs[0].transformer.copy_state(&lhs[slice_idx - 1].transformer)?
            }
            let slice = &mut self.slices[slice_idx];
            let xs = slice.linear_in.forward(xs)?;
            let xs = match last_token {
                Some(last_token) => {
                    let token_id = Tensor::from_vec(vec![last_token], (1, 1), dev)?;
                    let token_emb = slice.emb.forward(&token_id)?;
                    xs.broadcast_add(&token_emb)?
                }
                None => xs,
            };
            let xs = slice.transformer.forward(&xs)?;
            let logits = xs.apply(&slice.linear_out)?;
            let logits = match logits.dim(0)? {
                1 => logits.i((0, 0))?,
                b_size => candle::bail!("unexpected batch size {b_size}"),
            };
            let token = lp.sample(&logits)?;
            if VERBOSE.with(|v| *v) {
                println!("sampled {token} logits {slice_idx}:\n{logits}");
            }
            tokens.push(token);
            let token_for_next_layer =
                forced_audio_tokens.get(slice_idx).copied().flatten().unwrap_or(token);
            last_token = Some(token_for_next_layer);
        }
        Ok(tokens)
    }

    // Sampling with classifier-free guidance.
    pub fn sample_cfg(
        &mut self,
        xs: &Tensor,
        cfg_alpha: f64,
        text_token: Option<u32>,
        forced_audio_tokens: &[Option<u32>],
        lp: &mut candle_transformers::generation::LogitsProcessor,
    ) -> Result<Vec<u32>> {
        use crate::streaming::StreamingModule;
        let dev = xs.device();
        let mut tokens = Vec::with_capacity(self.slices.len());
        let mut last_token = text_token;
        for slice_idx in 0..self.slices.len() {
            if slice_idx == 0 {
                self.slices[slice_idx].transformer.reset_state();
            } else {
                let (lhs, rhs) = self.slices.split_at_mut(slice_idx);
                rhs[0].transformer.copy_state(&lhs[slice_idx - 1].transformer)?
            }
            let slice = &mut self.slices[slice_idx];
            let xs = slice.linear_in.forward(xs)?;
            let xs = match last_token {
                Some(last_token) => {
                    let token_id = Tensor::from_vec(vec![last_token], (1, 1), dev)?;
                    let token_emb = slice.emb.forward(&token_id)?;
                    xs.broadcast_add(&token_emb)?
                }
                None => xs,
            };
            let xs = slice.transformer.forward(&xs)?;
            let logits = xs.apply(&slice.linear_out)?;
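            // Classifier-free guidance expects a batch of two: the conditioned
            // logits in row 0 and the unconditioned logits in row 1, mixed as
            // cfg_alpha * cond - (cfg_alpha - 1) * uncond.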
            let logits = match logits.dim(0)? {
                2 => ((logits.i((0, 0))? * cfg_alpha)? - (logits.i((1, 0))? * (cfg_alpha - 1.))?)?,
                b_size => candle::bail!("unexpected batch size {b_size}"),
            };
            let token = lp.sample(&logits)?;
            if VERBOSE.with(|v| *v) {
                println!("sampled {token} logits {slice_idx}:\n{logits}");
            }
            tokens.push(token);
            let token_for_next_layer =
                forced_audio_tokens.get(slice_idx).copied().flatten().unwrap_or(token);
            last_token = Some(token_for_next_layer);
        }
        Ok(tokens)
    }
}

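/// Wrapper dispatching between the single-stream and batched transformer
/// implementations; the batched variant threads a `StreamMask` through
/// `forward` and `forward_ca`.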
#[derive(Debug, Clone)]
enum StreamingTransformer {
    Normal(transformer::StreamingTransformer),
    Batched(batched_transformer::StreamingTransformer),
}

impl crate::StreamingModule for StreamingTransformer {
    fn reset_state(&mut self) {
        match self {
            StreamingTransformer::Normal(t) => t.reset_state(),
            StreamingTransformer::Batched(t) => t.reset_state(),
        }
    }

    fn step(
        &mut self,
        xs: &crate::StreamTensor,
        mask: &crate::StreamMask,
    ) -> Result<crate::StreamTensor> {
        match self {
            StreamingTransformer::Normal(t) => t.step(xs, mask),
            StreamingTransformer::Batched(t) => t.step(xs, mask),
        }
    }
}

impl StreamingTransformer {
    fn reset_batch_idx(&mut self, batch_idx: usize, batch_size: usize) -> Result<()> {
        match self {
            StreamingTransformer::Normal(t) => t.reset_batch_idx(batch_idx, batch_size),
            StreamingTransformer::Batched(t) => t.reset_batch_idx(batch_idx),
        }
    }

    fn maybe_precompute_ca_kv(&self, ca_src: Option<CaSrc>) -> Result<Option<CaSrc>> {
        match self {
            StreamingTransformer::Normal(t) => t.maybe_precompute_ca_kv(ca_src),
            StreamingTransformer::Batched(t) => t.maybe_precompute_ca_kv(ca_src),
        }
    }

    fn forward(&mut self, xs: &Tensor, m: &StreamMask) -> Result<Tensor> {
        match self {
            StreamingTransformer::Normal(t) => t.forward(xs),
            StreamingTransformer::Batched(t) => t.forward(xs, m),
        }
    }

    fn forward_ca(
        &mut self,
        xs: &Tensor,
        ca_src: Option<&CaSrc>,
        m: &StreamMask,
    ) -> Result<Tensor> {
        match self {
            StreamingTransformer::Normal(t) => t.forward_ca(xs, ca_src),
            StreamingTransformer::Batched(t) => t.forward_ca(xs, ca_src, m),
        }
    }
}

#[derive(Debug, Clone)]
pub struct LmModel {
    transformer: StreamingTransformer,
    text_emb: MaybeQuantizedEmbedding,
    audio_embs: Vec<MaybeQuantizedEmbedding>,
    text_linear: MaybeQuantizedLinear,
    out_norm: transformer::Norm,
    depformer: Option<DepFormer>,
    audio_vocab_size: usize,
    text_in_vocab_size: usize,
    condition_provider: Option<crate::conditioner::ConditionProvider>,
    extra_heads: Vec<MaybeQuantizedLinear>,
    dtype: DType,
}

impl LmModel {
    pub fn new(cfg: &Config, vb: MaybeQuantizedVarBuilder) -> Result<Self> {
        Self::new_(None, cfg, vb)
    }

    pub fn batched(batch_size: usize, cfg: &Config, vb: MaybeQuantizedVarBuilder) -> Result<Self> {
        Self::new_(Some(batch_size), cfg, vb)
    }

    pub fn new_(
        batch_size: Option<usize>,
        cfg: &Config,
        vb: MaybeQuantizedVarBuilder,
    ) -> Result<Self> {
        let d_model = cfg.transformer.d_model;
        let depformer = match &cfg.depformer {
            None => None,
            Some(depformer_cfg) => {
                let depformer = DepFormer::new(
                    cfg.text_in_vocab_size,
                    cfg.audio_vocab_size,
                    d_model,
                    depformer_cfg,
                    vb.pp("depformer"),
                )?;
                Some(depformer)
            }
        };
        let text_emb =
            MaybeQuantizedEmbedding::new(cfg.text_in_vocab_size, d_model, vb.pp("text_emb"))?;
        let out_norm = transformer::Norm::new(d_model, &cfg.transformer, vb.pp("out_norm"))?;
        let text_linear = linear(d_model, cfg.text_out_vocab_size, false, vb.pp("text_linear"))?;
        let transformer = match batch_size {
            None => {
                let transformer =
                    transformer::StreamingTransformer::new(&cfg.transformer, vb.pp("transformer"))?;
                StreamingTransformer::Normal(transformer)
            }
            Some(batch_size) => {
                let transformer = batched_transformer::StreamingTransformer::new(
                    batch_size,
                    &cfg.transformer,
                    vb.pp("transformer"),
                )?;
                StreamingTransformer::Batched(transformer)
            }
        };
        let vb_e = vb.pp("emb");
        let mut audio_embs = Vec::with_capacity(cfg.audio_codebooks);
        for i in 0..cfg.audio_codebooks {
            let emb = MaybeQuantizedEmbedding::new(cfg.audio_vocab_size, d_model, vb_e.pp(i))?;
            audio_embs.push(emb)
        }
        let dtype = vb.dtype();
        let condition_provider = match cfg.conditioners.as_ref() {
            None => None,
            Some(cfg) => {
                let conditioners = crate::conditioner::ConditionProvider::new(
                    d_model,
                    cfg,
                    vb.pp("condition_provider"),
                )?;
                Some(conditioners)
            }
        };
        let mut extra_heads = vec![];
        if let Some(ExtraHeadsConfig { num_heads, dim }) = cfg.extra_heads {
            for i in 0..num_heads {
                let extra_head = linear(d_model, dim, false, vb.pp("extra_heads").pp(i))?;
                extra_heads.push(extra_head)
            }
        }
        Ok(Self {
            transformer,
            text_emb,
            text_linear,
            audio_embs,
            out_norm,
            depformer,
            text_in_vocab_size: cfg.text_in_vocab_size,
            audio_vocab_size: cfg.audio_vocab_size,
            condition_provider,
            extra_heads,
            dtype,
        })
    }

    pub fn condition_provider(&self) -> Option<&crate::conditioner::ConditionProvider> {
        self.condition_provider.as_ref()
    }

    pub fn reset_state(&mut self) {
        use crate::streaming::StreamingModule;
        self.transformer.reset_state()
    }

    pub fn in_audio_codebooks(&self) -> usize {
        self.audio_embs.len()
    }

    pub fn audio_pad_token(&self) -> u32 {
        self.audio_vocab_size as u32 - 1
    }

    pub fn text_start_token(&self) -> u32 {
        self.text_in_vocab_size as u32 - 1
    }

    pub fn generated_audio_codebooks(&self) -> usize {
        self.depformer.as_ref().map_or(0, |v| v.slices.len())
    }

    pub fn is_quantized(&self) -> bool {
        match self.text_linear {
            MaybeQuantizedLinear::Quantized(_) => true,
            MaybeQuantizedLinear::Real(_) => false,
        }
    }

    pub fn device(&self) -> &Device {
        self.text_emb.embeddings().device()
    }

    pub fn dtype(&self) -> DType {
        self.text_emb.embeddings().dtype()
    }

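    /// Run the main transformer on one step of input tokens. Returns the text
    /// logits together with the last hidden state, which can then be passed to
    /// the depformer to sample the audio tokens for this step.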
    pub fn forward(
        &mut self,
        text_ids: Option<Tensor>,
        audio_ids: Vec<Option<Tensor>>,
        mask: &StreamMask,
    ) -> candle::Result<(Tensor, Tensor)> {
        self.forward_cond(text_ids, audio_ids, None, mask)
    }

    pub fn extra_heads(&self, vs: &Tensor) -> Result<Vec<Tensor>> {
        let mut extra_heads = Vec::with_capacity(self.extra_heads.len());
        for extra_head in self.extra_heads.iter() {
            let extra_head = vs.apply(extra_head)?;
            extra_heads.push(extra_head)
        }
        Ok(extra_heads)
    }

    pub fn forward_cond(
        &mut self,
        text_ids: Option<Tensor>,
        audio_ids: Vec<Option<Tensor>>,
        conditions: Option<&crate::conditioner::Condition>,
        mask: &StreamMask,
    ) -> candle::Result<(Tensor, Tensor)> {
        if VERBOSE.with(|v| *v) {
            print!("text_ids ");
            if let Some(text_ids) = text_ids.as_ref() {
                let text_ids = text_ids.flatten_all()?.to_vec1::<u32>()?;
                println!("{text_ids:?}");
            } else {
                println!("none")
            }
            print!("audio_ids ");
            for audio_id in audio_ids.iter() {
                if let Some(audio_id) = audio_id {
                    let audio_id = audio_id.flatten_all()?.to_vec1::<u32>()?;
                    print!(" {audio_id:?}");
                } else {
                    print!(" none")
                }
            }
            println!();
        }
        let mut emb = match text_ids.as_ref() {
            Some(text_ids) => text_ids.apply(&self.text_emb)?,
            None => {
                let device = self.text_emb.embeddings().device();
                Tensor::zeros((1, 1, self.text_emb.hidden_size()?), self.dtype, device)?
            }
        };

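        // The input embedding is the sum of the text embedding and the
        // embeddings of whichever audio codebooks are provided.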
        for (audio_emb, audio_ids) in self.audio_embs.iter().zip(audio_ids.iter()) {
            if let Some(audio_ids) = audio_ids {
                let e = audio_ids.apply(audio_emb)?;
                emb = (emb + e)?
            }
        }
        if let Some(conditions) = conditions {
            match conditions {
                crate::conditioner::Condition::AddToInput(v) => emb = emb.broadcast_add(v)?,
            }
        }
        let ys = self.transformer.forward(&emb, mask)?;
        let ys = ys.apply(&self.out_norm)?;
        let logits = ys.apply(&self.text_linear)?;
        if VERBOSE.with(|v| *v) {
            println!("logits:\n{logits}");
        }
        Ok((logits, ys))
    }

    pub fn maybe_precompute_ca_kv(&self, ca_src: Option<CaSrc>) -> Result<Option<CaSrc>> {
        let ca_src = match ca_src {
            None => None,
            z => self.transformer.maybe_precompute_ca_kv(z)?,
        };
        Ok(ca_src)
    }

    pub fn forward_ca(
        &mut self,
        text_ids: Option<Tensor>,
        audio_ids: Vec<Option<Tensor>>,
        ca_src: &CaSrc,
        conditions: Option<&crate::conditioner::Condition>,
        mask: &StreamMask,
    ) -> candle::Result<(Tensor, Tensor)> {
        if VERBOSE.with(|v| *v) {
            print!("text_ids ");
            if let Some(text_ids) = text_ids.as_ref() {
                let text_ids = text_ids.flatten_all()?.to_vec1::<u32>()?;
                println!("{text_ids:?}");
            } else {
                println!("none")
            }
            print!("audio_ids ");
            for audio_id in audio_ids.iter() {
                if let Some(audio_id) = audio_id {
                    let audio_id = audio_id.flatten_all()?.to_vec1::<u32>()?;
                    print!(" {audio_id:?}");
                } else {
                    print!(" none")
                }
            }
            println!();
        }
        let b_size = match ca_src {
            CaSrc::KeysValues((cak, _)) => cak.dim(0)?,
            CaSrc::Tokens(catoks) => catoks.dim(0)?,
        };
        let mut emb = match text_ids {
            Some(text_ids) => text_ids.apply(&self.text_emb)?,
            None => {
                let device = self.text_emb.embeddings().device();
                Tensor::zeros((b_size, 1, self.text_emb.hidden_size()?), self.dtype, device)?
            }
        };
        for (audio_emb, audio_ids) in self.audio_embs.iter().zip(audio_ids.iter()) {
            if let Some(audio_ids) = audio_ids {
                let e = audio_ids.apply(audio_emb)?;
                emb = emb.broadcast_add(&e)?
            }
        }
        if let Some(conditions) = conditions {
            match conditions {
                crate::conditioner::Condition::AddToInput(v) => emb = emb.broadcast_add(v)?,
            }
        }
        let ys = self.transformer.forward_ca(&emb, Some(ca_src), mask)?;
        let ys = ys.apply(&self.out_norm)?;
        let logits = ys.apply(&self.text_linear)?;
        Ok((logits, ys))
    }

    pub fn depformer_sample(
        &mut self,
        xs: &Tensor,
        text_token: Option<u32>,
        forced_audio_tokens: &[Option<u32>],
        lp: &mut candle_transformers::generation::LogitsProcessor,
    ) -> Result<Option<Vec<u32>>> {
        let sample = match self.depformer.as_mut() {
            None => None,
            Some(m) => {
                let sample = m.sample(xs, text_token, forced_audio_tokens, lp)?;
                Some(sample)
            }
        };
        Ok(sample)
    }

    pub fn depformer_sample_cfg(
        &mut self,
        xs: &Tensor,
        cfg_alpha: f64,
        text_token: Option<u32>,
        forced_audio_tokens: &[Option<u32>],
        lp: &mut candle_transformers::generation::LogitsProcessor,
    ) -> Result<Option<Vec<u32>>> {
        let sample = match self.depformer.as_mut() {
            None => None,
            Some(m) => {
                let sample = m.sample_cfg(xs, cfg_alpha, text_token, forced_audio_tokens, lp)?;
                Some(sample)
            }
        };
        Ok(sample)
    }

    pub fn reset_batch_idx(&mut self, batch_idx: usize, batch_size: usize) -> Result<()> {
        self.transformer.reset_batch_idx(batch_idx, batch_size)
    }
}

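/// Load an `LmModel` from a weights file: a `.gguf` extension selects the
/// quantized GGUF loader, while any other extension is memory-mapped as
/// safetensors.
///
/// A minimal sketch (the file path is a placeholder):
/// ```ignore
/// let model = load_lm_model(Config::v0_1(), "lm.safetensors", DType::BF16, &Device::Cpu)?;
/// ```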
pub fn load_lm_model<P: AsRef<std::path::Path>>(
    cfg: Config,
    model_file: P,
    dtype: DType,
    dev: &Device,
) -> Result<LmModel> {
    let quantized = model_file.as_ref().extension().is_some_and(|v| v == "gguf");
    let vb = if quantized {
        MaybeQuantizedVarBuilder::Quantized(
            candle_transformers::quantized_var_builder::VarBuilder::from_gguf(model_file, dev)?,
        )
    } else {
        unsafe {
            MaybeQuantizedVarBuilder::Real(candle_nn::VarBuilder::from_mmaped_safetensors(
                &[model_file],
                dtype,
                dev,
            )?)
        }
    };
    let model = LmModel::new(&cfg, vb)?;
    Ok(model)
}

pub fn load<P: AsRef<std::path::Path>>(
    model_file: P,
    dtype: DType,
    dev: &Device,
) -> Result<LmModel> {
    let cfg = Config::v0_1();
    load_lm_model(cfg, model_file, dtype, dev)
}

pub fn load_streaming<P: AsRef<std::path::Path>>(
    model_file: P,
    dtype: DType,
    dev: &Device,
) -> Result<LmModel> {
    let cfg = Config::v0_1_streaming(8);
    load_lm_model(cfg, model_file, dtype, dev)
}

pub fn load_streaming_both_ways<P: AsRef<std::path::Path>>(
    model_file: P,
    dtype: DType,
    dev: &Device,
) -> Result<LmModel> {
    let cfg = Config::v0_1_streaming(16);
    load_lm_model(cfg, model_file, dtype, dev)
}

pub fn load_vision<P: AsRef<std::path::Path>>(
    model_file: P,
    override_cross_attention_gating: Option<transformer::CrossAttentionGating>,
    override_cross_attention_in_dim: Option<usize>,
    dtype: DType,
    dev: &Device,
) -> Result<LmModel> {
    // load_vision allows overriding some hyperparameters of the LM from the main config file.
    let mut cfg = Config::v0_1_vision_streaming(8);
    cfg.transformer.cross_attention = override_cross_attention_gating
        .map(|v| (v, cfg.transformer.norm, override_cross_attention_in_dim));
    load_lm_model(cfg, model_file, dtype, dev)
}

pub fn load_s2s<P: AsRef<std::path::Path>>(
    model_file: P,
    dtype: DType,
    dev: &Device,
) -> Result<LmModel> {
    let cfg = Config::s2s_2b_16rvq_202501();
    load_lm_model(cfg, model_file, dtype, dev)
}

pub fn load_asr<P: AsRef<std::path::Path>>(
    model_file: P,
    dtype: DType,
    dev: &Device,
) -> Result<LmModel> {
    let cfg = Config::asr_v0_1_1b();
    load_lm_model(cfg, model_file, dtype, dev)
}

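/// Audio tokens that are forced rather than sampled during the first
/// `acoustic_delay` steps: for each stream, the first codebook is left to be
/// sampled (`None`) while the remaining codebooks are forced to the audio
/// padding token.
///
/// A minimal usage sketch (the delay, step index, and codebook counts are
/// hypothetical):
/// ```ignore
/// let forced = ForcedAudioTokens::new(2, model.audio_pad_token(), &[8]);
/// let tokens = forced.forced_tokens(step_idx);
/// ```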
pub struct ForcedAudioTokens {
    acoustic_delay: usize,
    // Tokens that are teacher-forced before the acoustic delay.
    pre_delay_tokens: Vec<Option<u32>>,
}

impl ForcedAudioTokens {
    pub fn new(acoustic_delay: usize, audio_pad_token: u32, stream_codebooks: &[usize]) -> Self {
        let mut pre_delay_tokens = vec![];
        for codebooks in stream_codebooks.iter() {
            for c in 0..*codebooks {
                let token = if c == 0 { None } else { Some(audio_pad_token) };
                pre_delay_tokens.push(token);
            }
        }
        Self { acoustic_delay, pre_delay_tokens }
    }

    pub fn forced_tokens(&self, step_idx: usize) -> &[Option<u32>] {
        if step_idx < self.acoustic_delay {
            &self.pre_delay_tokens
        } else {
            &[]
        }
    }
}
1117}