// moshi_db/tts.rs

// Copyright (c) Kyutai, all rights reserved.
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

use crate::transformer::CaSrc;
use candle::{Context, DType, Result, Tensor, D};
use candle_nn::{linear_no_bias, Linear, VarBuilder};
use candle_transformers::models::t5;

/// TTS configuration: the T5 text-encoder, LM, and Mimi codec configs, plus
/// generation limits.
pub struct Config {
    pub t5: t5::Config,
    pub lm: crate::lm::Config,
    pub mimi: crate::mimi::Config,
    pub max_duration_s: f64,
    pub speaker_cond_duration_s: f64,
    pub max_speakers: usize,
}

impl Config {
    /// v0.1 preset: up to 5 speakers with 4 seconds of speaker-conditioning audio.
    pub fn v0_1(t5: t5::Config) -> Self {
        let lm = crate::lm::Config::tts_v0_1();
        let mimi = crate::mimi::Config::v0_1(None);
        Self { t5, lm, mimi, max_duration_s: 60., speaker_cond_duration_s: 4., max_speakers: 5 }
    }

    /// v0.2 preset: up to 2 speakers with 10 seconds of speaker-conditioning audio.
    pub fn v0_2(t5: t5::Config) -> Self {
        let lm = crate::lm::Config::tts_v0_1();
        let mimi = crate::mimi::Config::v0_1(None);
        Self { t5, lm, mimi, max_duration_s: 60., speaker_cond_duration_s: 10., max_speakers: 2 }
    }
}

/// Text-to-speech model: a T5 text encoder feeding a Moshi-style LM through
/// cross-attention, with optional Mimi-based speaker conditioning.
#[derive(Clone)]
pub struct Model {
    t5: t5::T5EncoderModel,
    pub lm: crate::lm::LmModel,
    speaker_cond: Option<(crate::mimi::Mimi, Linear)>,
    t5_proj: Linear,
    pub sample_rate: f64,
    frame_rate: f64,
    audio_vocab_size: u32,
    audio_codebooks: usize,
    pub max_duration_s: f64,
    max_speakers: usize,
    end_of_gen: Option<usize>,
}

impl Model {
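    /// Loads the model from three weight sets: the T5 text encoder, the LM (whose
    /// checkpoint also holds the conditioning projections), and, optionally, the
    /// Mimi encoder used for speaker conditioning.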
    pub fn new(
        cfg: &Config,
        vb_t5: VarBuilder,
        vb_lm: VarBuilder,
        vb_speaker_cond: Option<VarBuilder>,
    ) -> Result<Self> {
        let t5 = t5::T5EncoderModel::load(vb_t5, &cfg.t5)?;
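        // Optional speaker conditioning: a Mimi encoder for the reference audio plus
        // a projection from the Mimi latent dimension to the LM model dimension.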
        let speaker_cond = match vb_speaker_cond {
            None => None,
            Some(vb) => {
                let mimi = crate::mimi::Mimi::new(cfg.mimi.clone(), vb)?;
                let proj = linear_no_bias(
                    cfg.mimi.seanet.dimension,
                    cfg.lm.transformer.d_model,
                    vb_lm.pp("condition_provider.conditioners.speaker_wavs.output_proj"),
                )?;
                Some((mimi, proj))
            }
        };
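        // The transcript projection name depends on whether the checkpoint was
        // trained with diarized (speaker-labelled) transcripts.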
        let t5_proj = {
            let name = if speaker_cond.is_some() {
                "condition_provider.conditioners.diarized_transcript_in_segment.output_proj"
            } else {
                "condition_provider.conditioners.transcript_in_segment.output_proj"
            };
            linear_no_bias(cfg.t5.d_model, cfg.lm.transformer.d_model, vb_lm.pp(name))?
        };
        let lm =
            crate::lm::LmModel::new(&cfg.lm, crate::nn::MaybeQuantizedVarBuilder::Real(vb_lm))?;
        Ok(Self {
            t5,
            lm,
            speaker_cond,
            t5_proj,
            sample_rate: cfg.mimi.sample_rate,
            frame_rate: cfg.mimi.frame_rate,
            audio_vocab_size: cfg.lm.audio_vocab_size as u32,
            audio_codebooks: cfg.lm.audio_codebooks,
            max_duration_s: cfg.max_duration_s,
            max_speakers: cfg.max_speakers,
            end_of_gen: None,
        })
    }
}

/// Adds the standard sinusoidal positional embeddings to `xs` of shape
/// `(batch, seq_len, dim)`: position `p`, channel `i` uses frequency
/// `1 / 10000^(i / (dim/2 - 1))`, with cosines in the first half of the channels
/// and sines in the second.
pub fn add_sin_embeddings(xs: &Tensor) -> Result<Tensor> {
    let target_dtype = xs.dtype();
    let (_b_size, seq_len, dim) = xs.dims3()?;
    let dev = xs.device();
    let half_dim = dim / 2;
    let positions =
        Tensor::arange(0u32, seq_len as u32, dev)?.unsqueeze(1)?.to_dtype(DType::F32)?;
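    // Geometric ladder of inverse frequencies from 1 down to 1/10000, as in the
    // original transformer sinusoidal embedding.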
    let inv_freq: Vec<_> =
        (0..half_dim).map(|i| 1f32 / 10000f32.powf(i as f32 / (half_dim - 1) as f32)).collect();
    let inv_freq_len = inv_freq.len();
    let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
    let freqs = positions.broadcast_mul(&inv_freq)?;
    let pos_emb = Tensor::cat(&[freqs.cos()?, freqs.sin()?], D::Minus1)?;
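    // Add in f32 and cast back so low-precision inputs (e.g. bf16) keep the
    // positional signal intact.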
    let xs = xs.to_dtype(DType::F32)?.broadcast_add(&pos_emb)?;
    xs.to_dtype(target_dtype)
}

impl Model {
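    /// Builds the cross-attention conditioning tensor: the tokenized transcript is
    /// encoded by T5 and projected to the LM dimension; if `speaker_pcm` is given,
    /// it is encoded by Mimi and the result is stacked into two rows
    /// (speaker-conditioned and speaker-zeroed) for classifier-free guidance.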
    pub fn conditions(
        &mut self,
        token_ids: &Tensor,
        speaker_pcm: Option<&Tensor>,
    ) -> Result<Tensor> {
        let t5_condition =
            self.t5.forward(token_ids)?.to_dtype(candle::DType::BF16)?.apply(&self.t5_proj)?;
        let conditions = match speaker_pcm {
            None => t5_condition,
            Some(speaker_pcm) => {
                let sc = match self.speaker_cond.as_mut() {
                    None => candle::bail!("speaker_pcm specified without a speaker-cond model"),
                    Some((mimi, proj)) => mimi
                        .encode_pre_quantize(speaker_pcm)?
                        .t()?
                        .to_dtype(candle::DType::BF16)?
                        .apply(proj)?,
                };
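                // Two condition rows for classifier-free guidance: `c1` carries the
                // speaker embedding in the first speaker slot, `c2` zeroes it, and
                // both pad the remaining `max_speakers - 1` slots with zeros.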
                let z = sc.zeros_like()?;
                let mut c1 = vec![&t5_condition, &sc];
                let mut c2 = vec![&t5_condition, &z];
                for _ in 0..self.max_speakers - 1 {
                    c1.push(&z);
                    c2.push(&z);
                }
                let c1 = Tensor::cat(&c1, 1)?;
                let c2 = Tensor::cat(&c2, 1)?;
                let xs = Tensor::cat(&[&c1, &c2], 0)?;
                add_sin_embeddings(&xs)?
            }
        };
        Ok(conditions)
    }

    /// Samples audio tokens with a fixed RNG seed and top-k sampling
    /// (k = 100, temperature 0.8).
    pub fn sample(&mut self, conditions: &Tensor, cfg_alpha: f64) -> Result<Vec<Vec<u32>>> {
        let lp = candle_transformers::generation::LogitsProcessor::from_sampling(
            // Fixed seed for reproducible sampling.
            299792458,
            candle_transformers::generation::Sampling::TopK { k: 100, temperature: 0.8 },
        );
        self.sample_lp(conditions, cfg_alpha, lp)
    }

    /// Runs the autoregressive generation loop with the caller-supplied logits
    /// processor. Codebook 0 is fed back from the previous step and the acoustic
    /// codebooks with the two-step acoustic delay; generation stops two steps after
    /// the first end-of-generation token so the delayed acoustic tokens are completed.
    pub fn sample_lp(
        &mut self,
        conditions: &Tensor,
        cfg_alpha: f64,
        mut lp: candle_transformers::generation::LogitsProcessor,
    ) -> Result<Vec<Vec<u32>>> {
        // Reset any end-of-generation marker left over from a previous call.
        self.end_of_gen = None;
        let max_steps = (self.max_duration_s * self.frame_rate) as usize + 1;
        let audio_codebooks = self.audio_codebooks;
        let audio_vocab_size = self.audio_vocab_size;
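        // u32::MAX marks audio-token slots that have not been produced yet.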
        let mut audio_tokens: Vec<Vec<u32>> = vec![vec![u32::MAX; audio_codebooks]; max_steps + 2];
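        // During the initial delay steps the acoustic codebooks are forced to the
        // audio padding token rather than sampled.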
        let forced_audio_tokens = crate::lm::ForcedAudioTokens::new(
            /* acoustic_delay= */ 2,
            self.lm.audio_pad_token(),
            &[audio_codebooks],
        );
        let quantizer_bins = audio_vocab_size - 2; // 2048; tokens at or above this are special
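        // Two extra steps past max_steps let the delayed acoustic codebooks catch up
        // with the semantic stream.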
        for step_idx in 0..(max_steps + 2) {
            let mut codes = Vec::with_capacity(audio_codebooks);
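            // Input tokens for this step: codebook 0 (semantic) reads the previous
            // step's token; acoustic codebooks read three steps back, i.e. the
            // previous step's output shifted by the two-step acoustic delay. Steps
            // with no history yet use the initial token `audio_vocab_size - 1`.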
            for codebook in 0..audio_codebooks {
                let t = if codebook == 0 {
                    if step_idx == 0 {
                        audio_vocab_size - 1
                    } else {
                        audio_tokens[step_idx - 1][0]
                    }
                } else if step_idx <= 2 {
                    audio_vocab_size - 1
                } else {
                    audio_tokens[step_idx - 3][codebook]
                };
                let t = Tensor::new(&[t], conditions.device())?.unsqueeze(0)?;
                codes.push(Some(t));
            }
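            // One transformer step, cross-attending to the conditioning tensor.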
            let (_text_logits, ys) = self.lm.forward_ca(
                None,
                codes,
                &CaSrc::Tokens(conditions.clone()),
                None,
                &().into(),
            )?;
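            // With speaker conditioning, the conditions hold a (conditioned,
            // unconditioned) pair, so the depformer is sampled with classifier-free
            // guidance at strength `cfg_alpha`; otherwise it is sampled directly.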
            let last_audio_tokens = if self.speaker_cond.is_some() {
                self.lm.depformer_sample_cfg(
                    &ys,
                    cfg_alpha,
                    None,
                    forced_audio_tokens.forced_tokens(step_idx),
                    &mut lp,
                )?
            } else {
                self.lm.depformer_sample(
                    &ys,
                    None,
                    forced_audio_tokens.forced_tokens(step_idx),
                    &mut lp,
                )?
            };
            let last_audio_tokens = last_audio_tokens.context("no depformer")?;
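            // Store tokens at their undelayed positions: codebook 0 at the current
            // step, acoustic codebooks two steps earlier.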
            for (c_idx, token) in last_audio_tokens.into_iter().enumerate() {
                if step_idx > 0 && token >= quantizer_bins && self.end_of_gen.is_none() {
                    // Continue generating for two steps to get the final acoustic tokens.
                    self.end_of_gen = Some(step_idx + 2);
                }
                let delay = if c_idx == 0 { 0 } else { 2 };
                audio_tokens[step_idx.saturating_sub(delay)][c_idx] = token;
            }
            if Some(step_idx) == self.end_of_gen {
                break;
            }
        }
        Ok(audio_tokens)
    }
}