moshi_db/
tts_streaming.rs

1// Copyright (c) Kyutai, all rights reserved.
2// This source code is licensed under the license found in the
3// LICENSE file in the root directory of this source tree.
4
5use candle::{IndexOp, Result, Tensor};
6use candle_transformers::generation::LogitsProcessor;
7
8use crate::transformer::CaSrc;
9
/// Sentinel value marking a text/audio token slot that has not been generated yet.
pub const UNGENERATED: u32 = u32::MAX;
11
/// Hyper-parameters of the streaming TTS token interleaving scheme.
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
pub struct Config {
    /// Delay, in steps, applied to the acoustic codebooks relative to the first codebook.
    pub acoustic_delay: usize,
    /// Text token used for padding steps.
    pub text_pad_token: u32,
    /// Text begin-of-sequence token.
    pub text_bos_token: u32,
    /// Text end-of-sequence token.
    pub text_eos_token: u32,
    /// Text end-of-padding token, emitted between words in TTS mode.
    pub text_eop_token: u32,
    /// Token marking the start of the text stream.
    pub text_start_token: u32,
    /// Number of steps the audio stream lags behind the text stream
    /// (aka the interleaver delay).
    pub text_audio_delay_in_tokens: usize,
    /// Once more than this many consecutive pads have been emitted, an EOP token is forced.
    pub max_consecutive_pads: usize,
    // NOTE(review): not used in this file — presumably extra steps run after the prompt
    // is exhausted; confirm against callers.
    pub extra_steps: usize,
    /// Duration of the speaker-conditioning audio, in seconds.
    pub speaker_cond_duration_s: f64,
    /// Dimension of the speaker-conditioning embeddings.
    pub speaker_cond_dim: usize,
    /// Number of speaker slots used for conditioning.
    pub speaker_cond_n_speakers: usize,
}
27
28impl Config {
29    pub fn v202501() -> Self {
30        Self {
31            acoustic_delay: 2,
32            text_eop_token: 0,
33            text_bos_token: 1,
34            text_eos_token: 2,
35            text_pad_token: 3,
36            text_start_token: 8000,
37            text_audio_delay_in_tokens: 25, // aka interleaver_delay = 2s
38            max_consecutive_pads: 10,
39            extra_steps: 5,
40            speaker_cond_duration_s: 10.,
41            speaker_cond_dim: 2048,
42            speaker_cond_n_speakers: 5,
43        }
44    }
45}
46
/// Streaming TTS generation state: the language model together with its token
/// buffers and sampling machinery.
pub struct State {
    model: crate::lm::LmModel,
    // Optional cross-attention source, e.g. speaker-conditioning embeddings.
    ca_src: Option<CaSrc>,
    // audio_tokens[step][codebook], initialized to UNGENERATED.
    audio_tokens: Vec<Vec<u32>>,
    // One text token per step, initialized to UNGENERATED.
    text_tokens: Vec<u32>,
    // Number of consecutive text pad tokens emitted so far.
    consecutive_pads: usize,
    // Samplers for the audio codebooks and the text stream respectively.
    audio_lp: LogitsProcessor,
    text_lp: LogitsProcessor,
    step_idx: usize,
    forced_audio_tokens: crate::lm::ForcedAudioTokens,
    // Classifier-free-guidance weight; when set, forward passes use a batch of 2.
    cfg_alpha: Option<f64>,
    config: Config,
}
60
/// Constraint on the text token that `State::step` may produce at a given step.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AllowedTokens {
    /// Force this exact text token.
    Text(u32),
    /// Force the text pad token.
    Pad,
    /// Sample from the text logits, then map the result to either pad or EOP.
    PadOrEpad,
}
67
impl State {
    /// Creates a streaming TTS state.
    ///
    /// The text/audio token buffers are sized for `max_step_idx + config.acoustic_delay`
    /// steps and filled with `UNGENERATED`; `step` fails once this buffer is exhausted.
    /// When `cfg_alpha` is set, forward passes run with a batch of two rows so the logits
    /// can be combined with classifier-free guidance.
    pub fn new(
        model: crate::lm::LmModel,
        ca_src: Option<CaSrc>,
        max_step_idx: usize,
        audio_lp: LogitsProcessor,
        text_lp: LogitsProcessor,
        cfg_alpha: Option<f64>,
        config: Config,
    ) -> Self {
        let audio_tokens: Vec<Vec<u32>> = vec![
            vec![UNGENERATED; model.generated_audio_codebooks()];
            max_step_idx + config.acoustic_delay
        ];
        let text_tokens = vec![UNGENERATED; max_step_idx + config.acoustic_delay];
        let forced_audio_tokens = crate::lm::ForcedAudioTokens::new(
            config.acoustic_delay,
            model.audio_pad_token(),
            &[model.generated_audio_codebooks()],
        );
        Self {
            model,
            ca_src,
            audio_tokens,
            text_tokens,
            consecutive_pads: 0,
            audio_lp,
            text_lp,
            step_idx: 0,
            forced_audio_tokens,
            cfg_alpha,
            config,
        }
    }

    /// Number of steps performed so far.
    pub fn step_idx(&self) -> usize {
        self.step_idx
    }

    fn audio_pad_token(&self) -> u32 {
        self.model.audio_pad_token()
    }

    pub fn config(&self) -> &Config {
        &self.config
    }

    // The acoustic tokens are written with a delay, so this can create "gaps" of UNGENERATED
    // tokens in the case where we call `step_audio_prompt` *after* `step`.
    /// Runs a single generation step.
    ///
    /// Feeds `prev_text_token` together with the delayed audio tokens to the model, picks
    /// the next text token according to `allowed_tokens`, samples the audio codebooks via
    /// the depformer, and records everything in the internal buffers.
    /// Returns the text token selected for this step, or an error once the pre-allocated
    /// buffer is full.
    pub fn step(
        &mut self,
        prev_text_token: u32,
        allowed_tokens: AllowedTokens,
        conditions: Option<&crate::conditioner::Condition>,
    ) -> Result<u32> {
        let mut codes = Vec::with_capacity(self.model.generated_audio_codebooks());
        let dev = self.model.device();
        // With classifier-free guidance the model runs on two batch rows.
        let batch_size = if self.cfg_alpha.is_some() { 2 } else { 1 };
        // Build the model input for each codebook. `None` means "literal zeros" and
        // `Some(tok)` an actual token id; the first codebook has no acoustic delay,
        // the others are read `acoustic_delay` steps back.
        for codebook in 0..self.model.generated_audio_codebooks() {
            let t = if codebook == 0 {
                if self.step_idx == 0 {
                    Some(self.audio_pad_token())
                } else if self.step_idx <= self.config.text_audio_delay_in_tokens {
                    // The delayed pattern for TTS is a bit special, the audio-pad tokens are used
                    // in the same way as usual, i.e. for the first slice and until the acoustic
                    // delay for semantic tokens.
                    // However for the first couple seconds (set by `text_audio_delay_in_tokens`),
                    // the tokens that are *not* audio-pad are replaced by "literal zeros".
                    None
                } else {
                    Some(self.audio_tokens[self.step_idx - 1][codebook])
                }
            } else if self.step_idx <= self.config.acoustic_delay {
                Some(self.audio_pad_token())
            } else if self.step_idx
                <= self.config.text_audio_delay_in_tokens + self.config.acoustic_delay
            {
                // The same comment as above applies here.
                None
            } else {
                Some(self.audio_tokens[self.step_idx - self.config.acoustic_delay - 1][codebook])
            };
            if t == Some(UNGENERATED) {
                candle::bail!("internal error, ungenerated {}", self.step_idx)
            }
            // Replicate the token over the batch dimension (2 rows under CFG).
            let t = match t {
                Some(t) => Some(Tensor::from_vec(vec![t; batch_size], (batch_size, 1), dev)?),
                None => None,
            };
            codes.push(t)
        }
        let prev_text_token =
            Some(Tensor::from_vec(vec![prev_text_token; batch_size], (batch_size, 1), dev)?);
        // Forward through the LM, with cross-attention when a CA source is present.
        let (text_logits, ys) = match self.ca_src.as_ref() {
            None => self.model.forward_cond(prev_text_token, codes, conditions, &().into())?,
            Some(ca_src) => {
                self.model.forward_ca(prev_text_token, codes, ca_src, conditions, &().into())?
            }
        };
        // Classifier-free guidance: combine the two batch rows as a * l0 - (a - 1) * l1.
        let text_logits = match self.cfg_alpha {
            None => text_logits.i((0, 0))?,
            Some(a) => match text_logits.dim(0)? {
                2 => ((text_logits.i((0, 0))? * a)? - (text_logits.i((1, 0))? * (a - 1.))?)?,
                b_size => candle::bail!("unexpected batch size {b_size}"),
            },
        };
        // When in tts mode, there are only two possible outcomes corresponding to tokens 0 and 3.
        // 0 -> EOP or the next text token, this is ambiguous, a list of consecutive 0s correspond to
        //   word + EOP + word + EOP ...
        // 3 -> pad.
        // This will change when the simplerleaver lands.
        let text_token = match allowed_tokens {
            AllowedTokens::Text(v) => v,
            AllowedTokens::Pad => self.config.text_pad_token,
            AllowedTokens::PadOrEpad => {
                // Force an EOP once too many pads were produced in a row.
                if self.consecutive_pads > self.config.max_consecutive_pads {
                    self.config.text_eop_token
                } else {
                    // Sample, then collapse anything that is not a pad into an EOP.
                    let text_token = self.text_lp.sample(&text_logits)?;
                    if text_token == self.config.text_pad_token {
                        self.config.text_pad_token
                    } else {
                        self.config.text_eop_token
                    }
                }
            }
        };
        if text_token == self.config.text_pad_token {
            self.consecutive_pads += 1
        } else {
            self.consecutive_pads = 0
        }
        self.text_tokens[self.step_idx] = text_token;
        // No audio is sampled during the initial text/audio delay window.
        let last_audio_tokens = if self.step_idx < self.config.text_audio_delay_in_tokens {
            None
        } else {
            match self.cfg_alpha {
                None => self.model.depformer_sample(
                    &ys,
                    Some(text_token),
                    self.forced_audio_tokens.forced_tokens(self.step_idx),
                    &mut self.audio_lp,
                )?,
                Some(cfg_alpha) => self.model.depformer_sample_cfg(
                    &ys,
                    cfg_alpha,
                    Some(text_token),
                    self.forced_audio_tokens.forced_tokens(self.step_idx),
                    &mut self.audio_lp,
                )?,
            }
        };
        let audio_pad_token = self.audio_pad_token();
        // Record the sampled tokens, shifting acoustic codebooks back by the delay.
        // Only UNGENERATED slots are overwritten, so audio-prompt tokens are preserved.
        for c_idx in 0..self.model.generated_audio_codebooks() {
            let delay = if c_idx == 0 { 0 } else { self.config.acoustic_delay };
            let pos = &mut self.audio_tokens[self.step_idx.saturating_sub(delay)][c_idx];
            match last_audio_tokens.as_ref() {
                Some(lat) => {
                    if *pos == UNGENERATED {
                        *pos = lat[c_idx]
                    }
                }
                None => {
                    if *pos == UNGENERATED {
                        *pos = audio_pad_token
                    }
                }
            }
        }
        self.step_idx += 1;
        // NOTE(review): this errors out even though the current step completed and its
        // tokens were recorded above.
        if self.step_idx >= self.audio_tokens.len() {
            candle::bail!("max step-idx reached")
        }
        Ok(text_token)
    }

    /// Replaces the text token recorded for the previous step.
    ///
    /// Fails before any step has run, or when trying to write `UNGENERATED`.
    pub fn overwrite_last_text_token(&mut self, text_token: u32) -> Result<()> {
        if self.step_idx == 0 {
            candle::bail!("cannot overwrite first token")
        }
        if text_token == UNGENERATED {
            candle::bail!("cannot overwrite with UNGENERATED")
        }
        self.text_tokens[self.step_idx - 1] = text_token;
        Ok(())
    }

    /// If include_all is set, all the time steps are returned. Otherwise only the timesteps that
    /// have been generated are handled.
    pub fn audio_tokens(&self, include_all: bool) -> &[Vec<u32>] {
        if include_all {
            &self.audio_tokens
        } else {
            let max_idx = usize::min(self.step_idx, self.audio_tokens.len());
            &self.audio_tokens[..max_idx]
        }
    }

    /// Text tokens recorded so far; with `include_all`, the whole buffer including
    /// not-yet-generated (`UNGENERATED`) slots.
    pub fn text_tokens(&self, include_all: bool) -> &[u32] {
        if include_all {
            &self.text_tokens
        } else {
            let max_idx = usize::min(self.step_idx, self.text_tokens.len());
            &self.text_tokens[..max_idx]
        }
    }

    /// Most recent fully-generated audio frame, or `None` while still inside the
    /// acoustic delay or when the frame contains pad/ungenerated tokens.
    pub fn last_audio_tokens(&self) -> Option<Vec<u32>> {
        if self.step_idx <= self.config.acoustic_delay {
            None
        } else {
            // step_idx is in advance by 1 + there is a 2 token delay on audio tokens.
            let audio_tokens = &self.audio_tokens[self.step_idx - self.config.acoustic_delay - 1];
            // Both the pad token and UNGENERATED (u32::MAX) are >= audio_pad_token.
            if audio_tokens.iter().any(|v| *v >= self.audio_pad_token()) {
                None
            } else {
                Some(audio_tokens.clone())
            }
        }
    }

    /// Number of generated audio codebooks.
    pub fn audio_codebooks(&self) -> usize {
        self.model.generated_audio_codebooks()
    }

    pub fn device(&self) -> &candle::Device {
        self.model.device()
    }

    pub fn dtype(&self) -> candle::DType {
        self.model.dtype()
    }
}
301
/// Which of the two conversation sides a word belongs to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Speaker {
    /// Even-indexed turns.
    Main,
    /// Odd-indexed turns.
    Other,
}

/// Tokenizes a list of conversation turns into per-word token sequences.
///
/// Turns alternate between `Speaker::Main` (even indices) and `Speaker::Other`
/// (odd indices). The first word of each `Main` turn is prefixed with
/// `text_bos_token`; words that encode to nothing (e.g. from consecutive
/// spaces) are skipped. `encode` maps a single word to its token ids and its
/// error is propagated as-is.
///
/// Note that `text_eos_token` is only ever bound for `Other` turns, whose turn
/// token is never inserted, so it currently has no effect on the output.
pub fn tokenize_prompt<E>(
    text: &[String],
    text_bos_token: u32,
    text_eos_token: u32,
    encode: impl Fn(&str) -> std::result::Result<Vec<u32>, E>,
) -> std::result::Result<Vec<(Vec<u32>, Speaker)>, E> {
    let mut prompt = vec![];
    for (turn_idx, turn) in text.iter().enumerate() {
        let (speaker, turn_token) = if turn_idx % 2 == 0 {
            (Speaker::Main, text_bos_token)
        } else {
            (Speaker::Other, text_eos_token)
        };
        for (word_idx, word) in turn.split(' ').enumerate() {
            // `encode` already returns a Vec<u32>; the previous
            // `into_iter().collect::<Vec<_>>()` round-trip was a no-op.
            let mut word = encode(word)?;
            if word_idx == 0 && speaker == Speaker::Main {
                word.insert(0, turn_token)
            }
            if !word.is_empty() {
                prompt.push((word, speaker))
            }
        }
    }
    Ok(prompt)
}
333
/// Encodes reference speaker audio into conditioning embeddings for the TTS model.
#[derive(Debug, Clone)]
pub struct SpeakerEncoder {
    // Mimi codec used to extract pre-quantization embeddings from audio.
    mimi: crate::mimi::Mimi,
    // Learnt (1, 1, cond_dim) embedding used in place of missing speakers.
    learnt_padding: Tensor,
    // Projects mimi embeddings to the conditioning dimension.
    proj: candle_nn::Linear,
    // Number of speaker slots; extra input speakers are ignored, missing ones padded.
    n_speakers: usize,
    cond_dim: usize,
    device: candle::Device,
    dtype: candle::DType,
}
344
impl SpeakerEncoder {
    /// Loads the speaker-conditioning weights from `vb`.
    ///
    /// `learnt_padding` has shape `(1, 1, speaker_cond_dim)` and the projection maps the
    /// mimi seanet dimension to `speaker_cond_dim` without a bias.
    pub fn new(
        mimi: crate::mimi::Mimi,
        speaker_cond_dim: usize,
        speaker_cond_n_speakers: usize,
        dtype: candle::DType,
        vb: candle_nn::VarBuilder,
    ) -> Result<Self> {
        let learnt_padding = vb.get(
            (1, 1, speaker_cond_dim),
            "condition_provider.conditioners.speaker_wavs.learnt_padding",
        )?;
        let mimi_dim = mimi.config().seanet.dimension;
        let proj = candle_nn::linear_no_bias(
            mimi_dim,
            speaker_cond_dim,
            vb.pp("condition_provider.conditioners.speaker_wavs.output_proj"),
        )?;
        Ok(Self {
            mimi,
            learnt_padding,
            proj,
            n_speakers: speaker_cond_n_speakers,
            cond_dim: speaker_cond_dim,
            device: vb.device().clone(),
            dtype,
        })
    }

    pub fn device(&self) -> &candle::Device {
        &self.device
    }

    /// Sample rate expected for the speaker PCM input, taken from the mimi config.
    pub fn sample_rate(&self) -> f64 {
        self.mimi.config().sample_rate
    }

    /// Encodes up to `n_speakers` PCM tensors into a single conditioning sequence.
    ///
    /// Each PCM is rescaled so its standard deviation is ~0.08 (the mean is only used to
    /// compute the deviation, it is not subtracted). Missing speaker slots are filled
    /// with the learnt padding embedding. Fails on an empty `speakers` slice.
    pub fn encode(&self, speakers: &[Tensor]) -> Result<Tensor> {
        if speakers.is_empty() {
            candle::bail!("empty speakers in encode")
        }
        let mut pcms = vec![];
        for pcm in speakers.iter().take(self.n_speakers) {
            // Population standard deviation of the PCM samples.
            let stdev = pcm.broadcast_sub(&pcm.mean_all()?)?.sqr()?.mean_all()?.sqrt()?;
            let pcm = (pcm * 0.08)?.broadcast_div(&stdev)?;
            pcms.push(pcm)
        }
        let n_speakers = pcms.len();
        let pcm = Tensor::cat(&pcms, 0)?;
        // Clone so the shared mimi state can be reset without mutating `self`.
        let mut mimi = self.mimi.clone();
        mimi.reset_state();
        let embeddings = mimi.encode_pre_quantize(&pcm)?.t()?.apply(&self.proj)?;
        let embeddings = if n_speakers < self.n_speakers {
            // Learnt padding broadcast to a single-speaker embedding shape.
            let lp =
                embeddings.narrow(0, 0, 1)?.zeros_like()?.broadcast_add(&self.learnt_padding)?;
            // Append one `lp` copy per missing speaker slot.
            let mut embs = vec![embeddings];
            embs.resize(self.n_speakers - n_speakers + 1, lp);
            Tensor::cat(&embs, 0)?
        } else {
            embeddings
        };
        // Merge the speaker and time dimensions into one sequence, add a batch dim.
        let embeddings = embeddings.flatten(0, 1)?.unsqueeze(0)?;
        let embeddings = crate::tts::add_sin_embeddings(&embeddings)?;
        embeddings.to_dtype(self.dtype)
    }

    /// Conditioning sequence for the no-speaker case: the learnt padding broadcast over
    /// all speaker slots.
    pub fn empty(&self) -> Result<Tensor> {
        // NOTE(review): 125 looks like the per-speaker frame count for the conditioning
        // duration (10s at 12.5 Hz) — confirm against the mimi frame rate.
        let embeddings =
            self.learnt_padding.broadcast_as((1, self.n_speakers * 125, self.cond_dim))?;
        let embeddings = crate::tts::add_sin_embeddings(&embeddings)?;
        embeddings.to_dtype(self.dtype)
    }
}
417}