use crate::nn::{linear, MaybeQuantizedEmbedding, MaybeQuantizedLinear, MaybeQuantizedVarBuilder};
use crate::{
    batched_transformer,
    transformer::{self, CaSrc},
    NormType, StreamMask,
};
use candle::{DType, Device, IndexOp, Module, Result, Tensor};

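// Verbose debug logging, enabled by setting the MIMI_VERBOSE environment
// variable to a non-empty value other than "0".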
thread_local! {
    pub static VERBOSE: bool = {
        match std::env::var("MIMI_VERBOSE") {
            Ok(s) => !s.is_empty() && s != "0",
            Err(_) => false,
        }
    };
}
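
/// Configuration for the depth transformer ("depformer") that predicts the
/// audio codebook slices of a time step one after another.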
#[derive(Debug, Clone, serde::Deserialize)]
pub struct DepFormerConfig {
    pub transformer: transformer::Config,
    pub num_slices: usize,
    pub low_rank_embeddings: Option<usize>,
}

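/// Configuration for optional extra linear heads applied to the output of the
/// main transformer.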
#[derive(Debug, Clone, serde::Deserialize)]
pub struct ExtraHeadsConfig {
    pub num_heads: usize,
    pub dim: usize,
}

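/// Top-level language-model configuration: the main temporal transformer, an
/// optional depformer, the text and audio vocabulary sizes, and optional
/// conditioners and extra heads.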
#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
    pub transformer: transformer::Config,
    pub depformer: Option<DepFormerConfig>,
    pub text_in_vocab_size: usize,
    pub text_out_vocab_size: usize,
    pub audio_vocab_size: usize,
    pub audio_codebooks: usize,
    pub conditioners: Option<crate::conditioner::Config>,
    pub extra_heads: Option<ExtraHeadsConfig>,
}

impl Config {
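    /// Depformer configuration shared by the presets below. Note that the
    /// attention context is set to `num_slices`: the depformer only attends
    /// over the codebook slices of the current time step.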
    fn depformer_cfg(num_slices: usize) -> DepFormerConfig {
        let depformer_cfg = transformer::Config {
            d_model: 1024,
            num_heads: 16,
            num_layers: 6,
            dim_feedforward: 1024 * 4,
            causal: true,
            norm_first: true,
            bias_ff: false,
            bias_attn: false,
            layer_scale: None,
            context: num_slices,
            max_period: 10000,
            use_conv_block: false,
            use_conv_bias: true,
            cross_attention: None,
            gating: Some(candle_nn::Activation::Silu),
            norm: NormType::RmsNorm,
            positional_embedding: transformer::PositionalEmbedding::None,
            conv_layout: false,
            conv_kernel_size: 3,
            kv_repeat: 1,
            max_seq_len: 4096,
            shared_cross_attn: false,
        };
        DepFormerConfig { num_slices, transformer: depformer_cfg, low_rank_embeddings: None }
    }

    pub fn v0_1() -> Self {
        let lm_cfg = transformer::Config {
            d_model: 4096,
            num_heads: 32,
            num_layers: 32,
            dim_feedforward: 4096 * 4,
            causal: true,
            norm_first: true,
            bias_ff: false,
            bias_attn: false,
            layer_scale: None,
            context: 3000,
            max_period: 10000,
            use_conv_block: false,
            use_conv_bias: true,
            cross_attention: None,
            gating: Some(candle_nn::Activation::Silu),
            norm: NormType::RmsNorm,
            positional_embedding: transformer::PositionalEmbedding::Rope,
            conv_layout: false,
            conv_kernel_size: 3,
            kv_repeat: 1,
            max_seq_len: 4096,
            shared_cross_attn: false,
        };
        Self {
            transformer: lm_cfg,
            depformer: Some(Self::depformer_cfg(8)),
            audio_vocab_size: 2049,
            text_in_vocab_size: 32001,
            text_out_vocab_size: 32000,
            audio_codebooks: 8,
            conditioners: Default::default(),
            extra_heads: None,
        }
    }

    pub fn v0_1_vision() -> Self {
        let lm_cfg = transformer::Config {
            d_model: 4096,
            num_heads: 32,
            num_layers: 32,
            dim_feedforward: 4096 * 4,
            causal: true,
            norm_first: true,
            bias_ff: false,
            bias_attn: false,
            layer_scale: None,
            context: 3000,
            max_period: 10000,
            use_conv_block: false,
            use_conv_bias: true,
            cross_attention: Some((
                transformer::CrossAttentionGating::ConditionalGatedSigmoid,
                NormType::RmsNorm,
                None,
            )),
            gating: Some(candle_nn::Activation::Silu),
            norm: NormType::RmsNorm,
            positional_embedding: transformer::PositionalEmbedding::Rope,
            conv_layout: false,
            conv_kernel_size: 3,
            kv_repeat: 1,
            max_seq_len: 4096,
            shared_cross_attn: true,
        };
        Self {
            transformer: lm_cfg,
            depformer: Some(Self::depformer_cfg(8)),
            audio_vocab_size: 2049,
            text_in_vocab_size: 32001,
            text_out_vocab_size: 32000,
            audio_codebooks: 8,
            conditioners: Default::default(),
            extra_heads: None,
        }
    }

    pub fn v0_1_vision_streaming(num_slices: usize) -> Self {
        let mut s = Self::v0_1_vision();
        s.audio_codebooks = 16;
        if let Some(depformer) = s.depformer.as_mut() {
            depformer.num_slices = num_slices;
            depformer.transformer.context = num_slices;
        }
        s
    }

    pub fn v0_1_streaming(num_slices: usize) -> Self {
        let mut s = Self::v0_1();
        s.audio_codebooks = 16;
        if let Some(depformer) = s.depformer.as_mut() {
            depformer.num_slices = num_slices;
            depformer.transformer.context = num_slices;
        }
        s
    }

    pub fn v0_1_asr() -> Self {
        let mut s = Self::v0_1();
        s.audio_codebooks = 8;
        if let Some(depformer) = s.depformer.as_mut() {
            depformer.num_slices = 0;
            depformer.transformer.context = 0;
        }
        s
    }

    pub fn tts_v0_1() -> Self {
        let lm_cfg = transformer::Config {
            d_model: 2048,
            num_heads: 32,
            num_layers: 48,
            dim_feedforward: 4096 * 2,
            causal: true,
            norm_first: true,
            bias_ff: false,
            bias_attn: false,
            layer_scale: None,
            context: 4096,
            max_period: 10000,
            use_conv_block: false,
            use_conv_bias: true,
            cross_attention: Some((
                transformer::CrossAttentionGating::Normal,
                NormType::LayerNorm,
                None,
            )),
            gating: None,
            norm: NormType::LayerNorm,
            positional_embedding: transformer::PositionalEmbedding::Rope,
            conv_layout: false,
            conv_kernel_size: 3,
            kv_repeat: 1,
            max_seq_len: 4096,
            shared_cross_attn: false,
        };
        Self {
            transformer: lm_cfg,
            depformer: Some(Self::depformer_cfg(16)),
            audio_vocab_size: 2050,
            text_in_vocab_size: 32001,
            text_out_vocab_size: 32001,
            audio_codebooks: 16,
            conditioners: Default::default(),
            extra_heads: None,
        }
    }

    pub fn s2s_v0_1() -> Self {
        let lm_cfg = transformer::Config {
            d_model: 2048,
            num_heads: 16,
            num_layers: 16,
            dim_feedforward: 4096 * 2,
            causal: true,
            norm_first: true,
            bias_ff: false,
            bias_attn: false,
            layer_scale: None,
            context: 3000,
            max_period: 10000,
            use_conv_block: false,
            use_conv_bias: true,
            cross_attention: None,
            gating: Some(candle_nn::Activation::Silu),
            norm: NormType::RmsNorm,
            positional_embedding: transformer::PositionalEmbedding::Rope,
            conv_layout: false,
            conv_kernel_size: 3,
            kv_repeat: 1,
            max_seq_len: 4096,
            shared_cross_attn: false,
        };
        Self {
            transformer: lm_cfg,
            depformer: Some(Self::depformer_cfg(16)),
            audio_vocab_size: 2049,
            text_in_vocab_size: 48001,
            text_out_vocab_size: 48000,
            audio_codebooks: 16,
            conditioners: Default::default(),
            extra_heads: None,
        }
    }

    pub fn s2s_v0_1_streaming(num_slices: usize) -> Self {
        let mut s = Self::s2s_v0_1();
        s.audio_codebooks = 16;
        if let Some(depformer) = s.depformer.as_mut() {
            depformer.num_slices = num_slices;
            depformer.transformer.context = num_slices;
        }
        s
    }

    pub fn asr_v0_1_1b() -> Self {
        let lm_cfg = transformer::Config {
            d_model: 2048,
            num_heads: 16,
            num_layers: 16,
            dim_feedforward: 2048 * 4,
            causal: true,
            norm_first: true,
            bias_ff: false,
            bias_attn: false,
            layer_scale: None,
            context: 750,
            max_period: 100_000,
            use_conv_block: false,
            use_conv_bias: true,
            cross_attention: None,
            gating: Some(candle_nn::Activation::Silu),
            norm: NormType::RmsNorm,
            positional_embedding: transformer::PositionalEmbedding::Rope,
            conv_layout: false,
            conv_kernel_size: 3,
            kv_repeat: 1,
            max_seq_len: 4096,
            shared_cross_attn: false,
        };
        Self {
            transformer: lm_cfg,
            depformer: None,
            audio_vocab_size: 2049,
            text_in_vocab_size: 48001,
            text_out_vocab_size: 48000,
            audio_codebooks: 8,
            conditioners: Default::default(),
            extra_heads: None,
        }
    }

    pub fn asr_300m_202501() -> Self {
        let lm_cfg = transformer::Config {
            d_model: 1024,
            num_heads: 8,
            num_layers: 16,
            dim_feedforward: 1024 * 4,
            causal: true,
            norm_first: true,
            bias_ff: false,
            bias_attn: false,
            layer_scale: None,
            context: 750,
            max_period: 100_000,
            use_conv_block: false,
            use_conv_bias: true,
            cross_attention: None,
            gating: Some(candle_nn::Activation::Silu),
            norm: NormType::RmsNorm,
            positional_embedding: transformer::PositionalEmbedding::Rope,
            conv_layout: false,
            conv_kernel_size: 3,
            kv_repeat: 1,
            max_seq_len: 4096,
            shared_cross_attn: false,
        };
        Self {
            transformer: lm_cfg,
            depformer: None,
            audio_vocab_size: 2049,
            text_in_vocab_size: 48001,
            text_out_vocab_size: 48000,
            audio_codebooks: 32,
            conditioners: Default::default(),
            extra_heads: None,
        }
    }

    pub fn tts_202501() -> Self {
        let lm_cfg = transformer::Config {
            d_model: 2048,
            num_heads: 32,
            num_layers: 48,
            dim_feedforward: 2048 * 4,
            causal: true,
            norm_first: true,
            bias_ff: false,
            bias_attn: false,
            layer_scale: None,
            context: 500,
            max_period: 10000,
            use_conv_block: false,
            use_conv_bias: true,
            cross_attention: Some((
                transformer::CrossAttentionGating::Normal,
                NormType::LayerNorm,
                None,
            )),
            gating: Some(candle_nn::Activation::Silu),
            norm: NormType::RmsNorm,
            positional_embedding: transformer::PositionalEmbedding::Rope,
            conv_layout: false,
            conv_kernel_size: 3,
            kv_repeat: 1,
            max_seq_len: 4096,
            shared_cross_attn: false,
        };
        Self {
            transformer: lm_cfg,
            depformer: Some(Self::depformer_cfg(32)),
            audio_vocab_size: 2049,
            text_in_vocab_size: 8001,
            text_out_vocab_size: 8000,
            audio_codebooks: 32,
            conditioners: Default::default(),
            extra_heads: None,
        }
    }

    pub fn s2s_2b_16rvq_202501() -> Self {
        let lm_cfg = transformer::Config {
            d_model: 2560,
            num_heads: 20,
            num_layers: 24,
            dim_feedforward: 2560 * 4,
            causal: true,
            norm_first: true,
            bias_ff: false,
            bias_attn: false,
            layer_scale: None,
            context: 3000,
            max_period: 100_000,
            use_conv_block: false,
            use_conv_bias: true,
            cross_attention: None,
            gating: Some(candle_nn::Activation::Silu),
            norm: NormType::RmsNorm,
            positional_embedding: transformer::PositionalEmbedding::Rope,
            conv_layout: false,
            conv_kernel_size: 3,
            kv_repeat: 1,
            max_seq_len: 4096,
            shared_cross_attn: false,
        };
        Self {
            transformer: lm_cfg,
            depformer: Some(Self::depformer_cfg(16)),
            audio_vocab_size: 2049,
            text_in_vocab_size: 48001,
            text_out_vocab_size: 48000,
            audio_codebooks: 32,
            conditioners: Default::default(),
            extra_heads: None,
        }
    }
}

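/// Token embeddings with an optional low-rank factorization: tokens are first
/// embedded in a smaller `low_rank_dim` space and then projected up to the
/// model dimension with a linear layer.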
#[derive(Debug, Clone)]
struct LowRankEmbeddings {
    embeddings: MaybeQuantizedEmbedding,
    low_rank: Option<MaybeQuantizedLinear>,
}

impl LowRankEmbeddings {
    fn new(
        in_vocab_size: usize,
        dim: usize,
        low_rank_dim: Option<usize>,
        vb: MaybeQuantizedVarBuilder,
    ) -> Result<Self> {
        let (low_rank, embeddings) = match low_rank_dim {
            None => {
                let embeddings = MaybeQuantizedEmbedding::new(in_vocab_size, dim, vb)?;
                (None, embeddings)
            }
            Some(low_rank_dim) => {
                let low_rank = linear(low_rank_dim, dim, false, vb.pp("low_rank"))?;
                let embeddings = MaybeQuantizedEmbedding::new(in_vocab_size, low_rank_dim, vb)?;
                (Some(low_rank), embeddings)
            }
        };
        Ok(Self { embeddings, low_rank })
    }
}

impl Module for LowRankEmbeddings {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let embs = xs.apply(&self.embeddings)?;
        match self.low_rank.as_ref() {
            None => Ok(embs),
            Some(lr) => embs.apply(lr),
        }
    }
}

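/// A single depformer slice: one streaming transformer responsible for
/// predicting the tokens of one audio codebook.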
#[derive(Debug, Clone)]
struct DepFormerSlice {
    transformer: transformer::StreamingTransformer,
    // Embedding for the token sampled by the previous slice (or for the text
    // token in the case of the first slice).
    emb: LowRankEmbeddings,
    // Projects the main transformer output into this slice's dimension.
    linear_in: MaybeQuantizedLinear,
    // Projects hidden states to logits over the audio vocabulary.
    linear_out: MaybeQuantizedLinear,
}

impl DepFormerSlice {
    fn new(
        in_vocab_size: usize,
        out_vocab_size: usize,
        main_transformer_dim: usize,
        cfg: &DepFormerConfig,
        vb: MaybeQuantizedVarBuilder,
    ) -> Result<Self> {
        let dim = cfg.transformer.d_model;
        let transformer =
            transformer::StreamingTransformer::new(&cfg.transformer, vb.pp("transformer"))?;
        let emb =
            LowRankEmbeddings::new(in_vocab_size, dim, cfg.low_rank_embeddings, vb.pp("emb"))?;
        let linear_in = linear(main_transformer_dim, dim, false, vb.pp("linear_in"))?;
        let linear_out = linear(dim, out_vocab_size, false, vb.pp("linear_out"))?;
        Ok(Self { transformer, emb, linear_in, linear_out })
    }
}

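/// Depth transformer: a stack of slices run sequentially within one time step,
/// each slice consuming the token sampled by the previous one.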
#[derive(Debug, Clone)]
pub struct DepFormer {
    slices: Vec<DepFormerSlice>,
}

impl DepFormer {
    pub fn new(
        text_vocab_size: usize,
        audio_vocab_size: usize,
        main_transformer_dim: usize,
        cfg: &DepFormerConfig,
        vb: MaybeQuantizedVarBuilder,
    ) -> Result<Self> {
        let mut slices = Vec::with_capacity(cfg.num_slices);
        for slice_idx in 0..cfg.num_slices {
            // The first slice is conditioned on the text token, the following
            // ones on the audio token sampled for the previous codebook.
            let in_vs = if slice_idx == 0 { text_vocab_size } else { audio_vocab_size };
            let slice = DepFormerSlice::new(
                in_vs,
                // The output vocabulary excludes the audio padding token
                // (`audio_vocab_size - 1`), so it can never be sampled.
                audio_vocab_size - 1,
                main_transformer_dim,
                cfg,
                vb.pp(slice_idx),
            )?;
            slices.push(slice)
        }
        Ok(Self { slices })
    }

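    /// Sample one audio token per slice. `xs` is the last hidden state from
    /// the main transformer; `forced_audio_tokens` can pin the token passed on
    /// to the next slice (e.g. to force padding during the acoustic delay).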
    pub fn sample(
        &mut self,
        xs: &Tensor,
        text_token: Option<u32>,
        forced_audio_tokens: &[Option<u32>],
        lp: &mut candle_transformers::generation::LogitsProcessor,
    ) -> Result<Vec<u32>> {
        use crate::streaming::StreamingModule;
        let dev = xs.device();
        let mut tokens = Vec::with_capacity(self.slices.len());
        let mut last_token = text_token;
        for slice_idx in 0..self.slices.len() {
            if slice_idx == 0 {
                self.slices[slice_idx].transformer.reset_state();
            } else {
                let (lhs, rhs) = self.slices.split_at_mut(slice_idx);
                rhs[0].transformer.copy_state(&lhs[slice_idx - 1].transformer)?
            }
            let slice = &mut self.slices[slice_idx];
            let xs = slice.linear_in.forward(xs)?;
            let xs = match last_token {
                Some(last_token) => {
                    let token_id = Tensor::from_vec(vec![last_token], (1, 1), dev)?;
                    let token_emb = slice.emb.forward(&token_id)?;
                    xs.broadcast_add(&token_emb)?
                }
                None => xs,
            };
            let xs = slice.transformer.forward(&xs)?;
            let logits = xs.apply(&slice.linear_out)?;
            let logits = match logits.dim(0)? {
                1 => logits.i((0, 0))?,
                b_size => candle::bail!("unexpected batch size {b_size}"),
            };
            let token = lp.sample(&logits)?;
            if VERBOSE.with(|v| *v) {
                println!("sampled {token} logits {slice_idx}:\n{logits}");
            }
            tokens.push(token);
            let token_for_next_layer =
                forced_audio_tokens.get(slice_idx).copied().flatten().unwrap_or(token);
            last_token = Some(token_for_next_layer);
        }
        Ok(tokens)
    }

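    /// Same as [`Self::sample`] but with classifier-free guidance: `xs` holds
    /// a conditioned and an unconditioned batch entry, and the logits are
    /// mixed as `alpha * cond - (alpha - 1) * uncond`.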
    pub fn sample_cfg(
        &mut self,
        xs: &Tensor,
        cfg_alpha: f64,
        text_token: Option<u32>,
        forced_audio_tokens: &[Option<u32>],
        lp: &mut candle_transformers::generation::LogitsProcessor,
    ) -> Result<Vec<u32>> {
        use crate::streaming::StreamingModule;
        let dev = xs.device();
        let mut tokens = Vec::with_capacity(self.slices.len());
        let mut last_token = text_token;
        for slice_idx in 0..self.slices.len() {
            if slice_idx == 0 {
                self.slices[slice_idx].transformer.reset_state();
            } else {
                let (lhs, rhs) = self.slices.split_at_mut(slice_idx);
                rhs[0].transformer.copy_state(&lhs[slice_idx - 1].transformer)?
            }
            let slice = &mut self.slices[slice_idx];
            let xs = slice.linear_in.forward(xs)?;
            let xs = match last_token {
                Some(last_token) => {
                    let token_id = Tensor::from_vec(vec![last_token], (1, 1), dev)?;
                    let token_emb = slice.emb.forward(&token_id)?;
                    xs.broadcast_add(&token_emb)?
                }
                None => xs,
            };
            let xs = slice.transformer.forward(&xs)?;
            let logits = xs.apply(&slice.linear_out)?;
            let logits = match logits.dim(0)? {
                2 => ((logits.i((0, 0))? * cfg_alpha)? - (logits.i((1, 0))? * (cfg_alpha - 1.))?)?,
                b_size => candle::bail!("unexpected batch size {b_size}"),
            };
            let token = lp.sample(&logits)?;
            if VERBOSE.with(|v| *v) {
                println!("sampled {token} logits {slice_idx}:\n{logits}");
            }
            tokens.push(token);
            let token_for_next_layer =
                forced_audio_tokens.get(slice_idx).copied().flatten().unwrap_or(token);
            last_token = Some(token_for_next_layer);
        }
        Ok(tokens)
    }
}

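/// Dispatch between the single-stream transformer and its batched variant,
/// which keeps a separate streaming state per batch entry.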
#[derive(Debug, Clone)]
enum StreamingTransformer {
    Normal(transformer::StreamingTransformer),
    Batched(batched_transformer::StreamingTransformer),
}

impl crate::StreamingModule for StreamingTransformer {
    fn reset_state(&mut self) {
        match self {
            StreamingTransformer::Normal(t) => t.reset_state(),
            StreamingTransformer::Batched(t) => t.reset_state(),
        }
    }

    fn step(
        &mut self,
        xs: &crate::StreamTensor,
        mask: &crate::StreamMask,
    ) -> Result<crate::StreamTensor> {
        match self {
            StreamingTransformer::Normal(t) => t.step(xs, mask),
            StreamingTransformer::Batched(t) => t.step(xs, mask),
        }
    }
}

impl StreamingTransformer {
    fn reset_batch_idx(&mut self, batch_idx: usize, batch_size: usize) -> Result<()> {
        match self {
            StreamingTransformer::Normal(t) => t.reset_batch_idx(batch_idx, batch_size),
            StreamingTransformer::Batched(t) => t.reset_batch_idx(batch_idx),
        }
    }

    fn maybe_precompute_ca_kv(&self, ca_src: Option<CaSrc>) -> Result<Option<CaSrc>> {
        match self {
            StreamingTransformer::Normal(t) => t.maybe_precompute_ca_kv(ca_src),
            StreamingTransformer::Batched(t) => t.maybe_precompute_ca_kv(ca_src),
        }
    }

    fn forward(&mut self, xs: &Tensor, m: &StreamMask) -> Result<Tensor> {
        match self {
            StreamingTransformer::Normal(t) => t.forward(xs),
            StreamingTransformer::Batched(t) => t.forward(xs, m),
        }
    }

    fn forward_ca(
        &mut self,
        xs: &Tensor,
        ca_src: Option<&CaSrc>,
        m: &StreamMask,
    ) -> Result<Tensor> {
        match self {
            StreamingTransformer::Normal(t) => t.forward_ca(xs, ca_src),
            StreamingTransformer::Batched(t) => t.forward_ca(xs, ca_src, m),
        }
    }
}

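/// The full language model: a main temporal transformer over the summed text
/// and audio embeddings, followed by a text head and an optional depformer
/// for the audio codebooks.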
#[derive(Debug, Clone)]
pub struct LmModel {
    transformer: StreamingTransformer,
    text_emb: MaybeQuantizedEmbedding,
    audio_embs: Vec<MaybeQuantizedEmbedding>,
    text_linear: MaybeQuantizedLinear,
    out_norm: transformer::Norm,
    depformer: Option<DepFormer>,
    audio_vocab_size: usize,
    text_in_vocab_size: usize,
    condition_provider: Option<crate::conditioner::ConditionProvider>,
    extra_heads: Vec<MaybeQuantizedLinear>,
    dtype: DType,
}

impl LmModel {
    pub fn new(cfg: &Config, vb: MaybeQuantizedVarBuilder) -> Result<Self> {
        Self::new_(None, cfg, vb)
    }

    pub fn batched(batch_size: usize, cfg: &Config, vb: MaybeQuantizedVarBuilder) -> Result<Self> {
        Self::new_(Some(batch_size), cfg, vb)
    }

    pub fn new_(
        batch_size: Option<usize>,
        cfg: &Config,
        vb: MaybeQuantizedVarBuilder,
    ) -> Result<Self> {
        let d_model = cfg.transformer.d_model;
        let depformer = match &cfg.depformer {
            None => None,
            Some(depformer_cfg) => {
                let depformer = DepFormer::new(
                    cfg.text_in_vocab_size,
                    cfg.audio_vocab_size,
                    d_model,
                    depformer_cfg,
                    vb.pp("depformer"),
                )?;
                Some(depformer)
            }
        };
        let text_emb =
            MaybeQuantizedEmbedding::new(cfg.text_in_vocab_size, d_model, vb.pp("text_emb"))?;
        let out_norm = transformer::Norm::new(d_model, &cfg.transformer, vb.pp("out_norm"))?;
        let text_linear = linear(d_model, cfg.text_out_vocab_size, false, vb.pp("text_linear"))?;
        let transformer = match batch_size {
            None => {
                let transformer =
                    transformer::StreamingTransformer::new(&cfg.transformer, vb.pp("transformer"))?;
                StreamingTransformer::Normal(transformer)
            }
            Some(batch_size) => {
                let transformer = batched_transformer::StreamingTransformer::new(
                    batch_size,
                    &cfg.transformer,
                    vb.pp("transformer"),
                )?;
                StreamingTransformer::Batched(transformer)
            }
        };
        let vb_e = vb.pp("emb");
        let mut audio_embs = Vec::with_capacity(cfg.audio_codebooks);
        for i in 0..cfg.audio_codebooks {
            let emb = MaybeQuantizedEmbedding::new(cfg.audio_vocab_size, d_model, vb_e.pp(i))?;
            audio_embs.push(emb)
        }
        let dtype = vb.dtype();
        let condition_provider = match cfg.conditioners.as_ref() {
            None => None,
            Some(cfg) => {
                let conditioners = crate::conditioner::ConditionProvider::new(
                    d_model,
                    cfg,
                    vb.pp("condition_provider"),
                )?;
                Some(conditioners)
            }
        };
        let mut extra_heads = vec![];
        if let Some(ExtraHeadsConfig { num_heads, dim }) = cfg.extra_heads {
            for i in 0..num_heads {
                let extra_head = linear(d_model, dim, false, vb.pp("extra_heads").pp(i))?;
                extra_heads.push(extra_head)
            }
        }
        Ok(Self {
            transformer,
            text_emb,
            text_linear,
            audio_embs,
            out_norm,
            depformer,
            text_in_vocab_size: cfg.text_in_vocab_size,
            audio_vocab_size: cfg.audio_vocab_size,
            condition_provider,
            extra_heads,
            dtype,
        })
    }

    pub fn condition_provider(&self) -> Option<&crate::conditioner::ConditionProvider> {
        self.condition_provider.as_ref()
    }

    pub fn reset_state(&mut self) {
        use crate::streaming::StreamingModule;
        self.transformer.reset_state()
    }

    pub fn in_audio_codebooks(&self) -> usize {
        self.audio_embs.len()
    }

    pub fn audio_pad_token(&self) -> u32 {
        self.audio_vocab_size as u32 - 1
    }

    pub fn text_start_token(&self) -> u32 {
        self.text_in_vocab_size as u32 - 1
    }

    pub fn generated_audio_codebooks(&self) -> usize {
        self.depformer.as_ref().map_or(0, |v| v.slices.len())
    }

    pub fn is_quantized(&self) -> bool {
        match self.text_linear {
            MaybeQuantizedLinear::Quantized(_) => true,
            MaybeQuantizedLinear::Real(_) => false,
        }
    }

    pub fn device(&self) -> &Device {
        self.text_emb.embeddings().device()
    }

    pub fn dtype(&self) -> DType {
        self.text_emb.embeddings().dtype()
    }

    pub fn forward(
        &mut self,
        text_ids: Option<Tensor>,
        audio_ids: Vec<Option<Tensor>>,
        mask: &StreamMask,
    ) -> candle::Result<(Tensor, Tensor)> {
        self.forward_cond(text_ids, audio_ids, None, mask)
    }

    pub fn extra_heads(&self, vs: &Tensor) -> Result<Vec<Tensor>> {
        let mut extra_heads = Vec::with_capacity(self.extra_heads.len());
        for extra_head in self.extra_heads.iter() {
            let extra_head = vs.apply(extra_head)?;
            extra_heads.push(extra_head)
        }
        Ok(extra_heads)
    }

    pub fn forward_cond(
        &mut self,
        text_ids: Option<Tensor>,
        audio_ids: Vec<Option<Tensor>>,
        conditions: Option<&crate::conditioner::Condition>,
        mask: &StreamMask,
    ) -> candle::Result<(Tensor, Tensor)> {
        if VERBOSE.with(|v| *v) {
            print!("text_ids ");
            if let Some(text_ids) = text_ids.as_ref() {
                let text_ids = text_ids.flatten_all()?.to_vec1::<u32>()?;
                println!("{text_ids:?}");
            } else {
                println!("none")
            }
            print!("audio_ids ");
            for audio_id in audio_ids.iter() {
                if let Some(audio_id) = audio_id {
                    let audio_id = audio_id.flatten_all()?.to_vec1::<u32>()?;
                    print!(" {audio_id:?}");
                } else {
                    print!(" none")
                }
            }
            println!();
        }
        let mut emb = match text_ids.as_ref() {
            Some(text_ids) => text_ids.apply(&self.text_emb)?,
            None => {
                let device = self.text_emb.embeddings().device();
                Tensor::zeros((1, 1, self.text_emb.hidden_size()?), self.dtype, device)?
            }
        };

        for (audio_emb, audio_ids) in self.audio_embs.iter().zip(audio_ids.iter()) {
            if let Some(audio_ids) = audio_ids {
                let e = audio_ids.apply(audio_emb)?;
                emb = (emb + e)?
            }
        }
        if let Some(conditions) = conditions {
            match conditions {
                crate::conditioner::Condition::AddToInput(v) => emb = emb.broadcast_add(v)?,
            }
        }
        let ys = self.transformer.forward(&emb, mask)?;
        let ys = ys.apply(&self.out_norm)?;
        let logits = ys.apply(&self.text_linear)?;
        if VERBOSE.with(|v| *v) {
            println!("logits:\n{logits}");
        }
        Ok((logits, ys))
    }

    pub fn maybe_precompute_ca_kv(&self, ca_src: Option<CaSrc>) -> Result<Option<CaSrc>> {
        let ca_src = match ca_src {
            None => None,
            z => self.transformer.maybe_precompute_ca_kv(z)?,
        };
        Ok(ca_src)
    }

    pub fn forward_ca(
        &mut self,
        text_ids: Option<Tensor>,
        audio_ids: Vec<Option<Tensor>>,
        ca_src: &CaSrc,
        conditions: Option<&crate::conditioner::Condition>,
        mask: &StreamMask,
    ) -> candle::Result<(Tensor, Tensor)> {
        if VERBOSE.with(|v| *v) {
            print!("text_ids ");
            if let Some(text_ids) = text_ids.as_ref() {
                let text_ids = text_ids.flatten_all()?.to_vec1::<u32>()?;
                println!("{text_ids:?}");
            } else {
                println!("none")
            }
            print!("audio_ids ");
            for audio_id in audio_ids.iter() {
                if let Some(audio_id) = audio_id {
                    let audio_id = audio_id.flatten_all()?.to_vec1::<u32>()?;
                    print!(" {audio_id:?}");
                } else {
                    print!(" none")
                }
            }
            println!();
        }
        let b_size = match ca_src {
            CaSrc::KeysValues((cak, _)) => cak.dim(0)?,
            CaSrc::Tokens(catoks) => catoks.dim(0)?,
        };
        let mut emb = match text_ids {
            Some(text_ids) => text_ids.apply(&self.text_emb)?,
            None => {
                let device = self.text_emb.embeddings().device();
                Tensor::zeros((b_size, 1, self.text_emb.hidden_size()?), self.dtype, device)?
            }
        };
        for (audio_emb, audio_ids) in self.audio_embs.iter().zip(audio_ids.iter()) {
            if let Some(audio_ids) = audio_ids {
                let e = audio_ids.apply(audio_emb)?;
                emb = emb.broadcast_add(&e)?
            }
        }
        if let Some(conditions) = conditions {
            match conditions {
                crate::conditioner::Condition::AddToInput(v) => emb = emb.broadcast_add(v)?,
            }
        }
        let ys = self.transformer.forward_ca(&emb, Some(ca_src), mask)?;
        let ys = ys.apply(&self.out_norm)?;
        let logits = ys.apply(&self.text_linear)?;
        Ok((logits, ys))
    }

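    /// Run the depformer on the last hidden state to sample the audio tokens
    /// for the current step; returns `None` when the model has no depformer.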
    pub fn depformer_sample(
        &mut self,
        xs: &Tensor,
        text_token: Option<u32>,
        forced_audio_tokens: &[Option<u32>],
        lp: &mut candle_transformers::generation::LogitsProcessor,
    ) -> Result<Option<Vec<u32>>> {
        let sample = match self.depformer.as_mut() {
            None => None,
            Some(m) => {
                let sample = m.sample(xs, text_token, forced_audio_tokens, lp)?;
                Some(sample)
            }
        };
        Ok(sample)
    }

    pub fn depformer_sample_cfg(
        &mut self,
        xs: &Tensor,
        cfg_alpha: f64,
        text_token: Option<u32>,
        forced_audio_tokens: &[Option<u32>],
        lp: &mut candle_transformers::generation::LogitsProcessor,
    ) -> Result<Option<Vec<u32>>> {
        let sample = match self.depformer.as_mut() {
            None => None,
            Some(m) => {
                let sample = m.sample_cfg(xs, cfg_alpha, text_token, forced_audio_tokens, lp)?;
                Some(sample)
            }
        };
        Ok(sample)
    }

    pub fn reset_batch_idx(&mut self, batch_idx: usize, batch_size: usize) -> Result<()> {
        self.transformer.reset_batch_idx(batch_idx, batch_size)
    }
}

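/// Load an [`LmModel`] from a weights file. Files with a `gguf` extension are
/// loaded through the quantized var-builder; anything else is treated as
/// mmap'd safetensors.
///
/// A minimal usage sketch; the file path and the `moshi::lm` crate path are
/// assumptions, not checked by this doc-test:
/// ```no_run
/// # fn run() -> candle::Result<()> {
/// let dev = candle::Device::Cpu;
/// let cfg = moshi::lm::Config::v0_1();
/// let model = moshi::lm::load_lm_model(cfg, "model.safetensors", candle::DType::F32, &dev)?;
/// # Ok(())
/// # }
/// ```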
pub fn load_lm_model<P: AsRef<std::path::Path>>(
    cfg: Config,
    model_file: P,
    dtype: DType,
    dev: &Device,
) -> Result<LmModel> {
    let quantized = model_file.as_ref().extension().is_some_and(|v| v == "gguf");
    let vb = if quantized {
        MaybeQuantizedVarBuilder::Quantized(
            candle_transformers::quantized_var_builder::VarBuilder::from_gguf(model_file, dev)?,
        )
    } else {
        unsafe {
            MaybeQuantizedVarBuilder::Real(candle_nn::VarBuilder::from_mmaped_safetensors(
                &[model_file],
                dtype,
                dev,
            )?)
        }
    };
    let model = LmModel::new(&cfg, vb)?;
    Ok(model)
}

pub fn load<P: AsRef<std::path::Path>>(
    model_file: P,
    dtype: DType,
    dev: &Device,
) -> Result<LmModel> {
    let cfg = Config::v0_1();
    load_lm_model(cfg, model_file, dtype, dev)
}

pub fn load_streaming<P: AsRef<std::path::Path>>(
    model_file: P,
    dtype: DType,
    dev: &Device,
) -> Result<LmModel> {
    let cfg = Config::v0_1_streaming(8);
    load_lm_model(cfg, model_file, dtype, dev)
}

pub fn load_streaming_both_ways<P: AsRef<std::path::Path>>(
    model_file: P,
    dtype: DType,
    dev: &Device,
) -> Result<LmModel> {
    let cfg = Config::v0_1_streaming(16);
    load_lm_model(cfg, model_file, dtype, dev)
}

pub fn load_vision<P: AsRef<std::path::Path>>(
    model_file: P,
    override_cross_attention_gating: Option<transformer::CrossAttentionGating>,
    override_cross_attention_in_dim: Option<usize>,
    dtype: DType,
    dev: &Device,
) -> Result<LmModel> {
    let mut cfg = Config::v0_1_vision_streaming(8);
    cfg.transformer.cross_attention = override_cross_attention_gating
        .map(|v| (v, cfg.transformer.norm, override_cross_attention_in_dim));
    load_lm_model(cfg, model_file, dtype, dev)
}

pub fn load_s2s<P: AsRef<std::path::Path>>(
    model_file: P,
    dtype: DType,
    dev: &Device,
) -> Result<LmModel> {
    let cfg = Config::s2s_2b_16rvq_202501();
    load_lm_model(cfg, model_file, dtype, dev)
}

pub fn load_asr<P: AsRef<std::path::Path>>(
    model_file: P,
    dtype: DType,
    dev: &Device,
) -> Result<LmModel> {
    let cfg = Config::asr_v0_1_1b();
    load_lm_model(cfg, model_file, dtype, dev)
}

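/// Audio tokens forced during the first `acoustic_delay` steps: every codebook
/// except the first (semantic) one of each stream is pinned to the padding
/// token until the delay has elapsed.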
pub struct ForcedAudioTokens {
    acoustic_delay: usize,
    pre_delay_tokens: Vec<Option<u32>>,
}

impl ForcedAudioTokens {
    pub fn new(acoustic_delay: usize, audio_pad_token: u32, stream_codebooks: &[usize]) -> Self {
        let mut pre_delay_tokens = vec![];
        for codebooks in stream_codebooks.iter() {
            for c in 0..*codebooks {
                let token = if c == 0 { None } else { Some(audio_pad_token) };
                pre_delay_tokens.push(token);
            }
        }
        Self { acoustic_delay, pre_delay_tokens }
    }

    pub fn forced_tokens(&self, step_idx: usize) -> &[Option<u32>] {
        if step_idx < self.acoustic_delay {
            &self.pre_delay_tokens
        } else {
            &[]
        }
    }
}