moshi_db/lm_generate_multistream.rs

// Copyright (c) Kyutai, all rights reserved.
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

use candle::{IndexOp, Tensor};
use candle_transformers::generation::LogitsProcessor;

use crate::transformer::CaSrc;

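/// Sentinel value marking a token slot that has not been generated yet.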
pub const UNGENERATED: u32 = u32::MAX;

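/// Configuration of the multistream token layout: how many audio codebooks are
/// generated vs. provided as input, the audio vocabulary size, the delay applied
/// to acoustic codebooks, and the special text tokens.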
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
pub struct Config {
    pub generated_audio_codebooks: usize,
    pub input_audio_codebooks: usize,
    pub audio_vocab_size: usize,
    pub acoustic_delay: usize,
    pub text_pad_token: u32,
    pub text_eop_token: u32,
    pub text_start_token: u32,
}

impl Config {
    pub fn v0_1() -> Self {
        Self {
            generated_audio_codebooks: 8,
            input_audio_codebooks: 8,
            audio_vocab_size: 2049,
            acoustic_delay: 2,
            text_eop_token: 0,
            text_pad_token: 3,
            text_start_token: 32000,
        }
    }

    pub fn v0_1_two_ways() -> Self {
        Self {
            generated_audio_codebooks: 16,
            input_audio_codebooks: 0,
            audio_vocab_size: 2049,
            acoustic_delay: 2,
            text_eop_token: 0,
            text_pad_token: 3,
            text_start_token: 32000,
        }
    }

    pub fn v0_1_one_way() -> Self {
        Self {
            generated_audio_codebooks: 8,
            input_audio_codebooks: 0,
            audio_vocab_size: 2049,
            acoustic_delay: 2,
            text_eop_token: 0,
            text_pad_token: 3,
            text_start_token: 32000,
        }
    }

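    /// The audio pad token is the last entry of the audio vocabulary. A minimal
    /// doctest sketch (the `moshi_db` crate path is an assumption, hence `ignore`):
    ///
    /// ```ignore
    /// use moshi_db::lm_generate_multistream::Config;
    /// let cfg = Config::v0_1();
    /// // audio_vocab_size is 2049, so the pad token is 2048.
    /// assert_eq!(cfg.audio_pad_token(), 2048);
    /// // 8 generated + 8 input codebooks.
    /// assert_eq!(cfg.total_audio_codebooks(), 16);
    /// ```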
    pub fn audio_pad_token(&self) -> u32 {
        self.audio_vocab_size as u32 - 1
    }

    pub fn total_audio_codebooks(&self) -> usize {
        self.generated_audio_codebooks + self.input_audio_codebooks
    }
}

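/// Streaming generation state: the language model plus per-step buffers of text
/// and audio tokens, the samplers for both modalities, and the knobs applied at
/// each step (pad boost, repetition penalty, classifier-free guidance).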
pub struct State {
    model: crate::lm::LmModel,
    audio_tokens: Vec<Vec<u32>>,
    text_tokens: Vec<u32>,
    audio_lp: LogitsProcessor,
    text_lp: LogitsProcessor,
    step_idx: usize,
    pad_mult: Option<f32>,
    // For the repetition penalty, we provide the context length (in text tokens) and the penalty.
    repetition_penalty: Option<(usize, f32)>,
    forced_audio_tokens: crate::lm::ForcedAudioTokens,
    user_rating: u32,
    cfg_alpha: Option<f64>,
    config: Config,
}

impl State {
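    /// Builds a fresh generation state. The token buffers are sized
    /// `max_step_idx + acoustic_delay` so that delayed acoustic tokens written
    /// near the end of generation still have a slot to land in.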
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        model: crate::lm::LmModel,
        max_step_idx: usize,
        audio_lp: LogitsProcessor,
        text_lp: LogitsProcessor,
        pad_mult: Option<f32>,
        repetition_penalty: Option<(usize, f32)>,
        cfg_alpha: Option<f64>,
        config: Config,
    ) -> Self {
        let audio_tokens: Vec<Vec<u32>> = vec![
            vec![UNGENERATED; config.total_audio_codebooks()];
            max_step_idx + config.acoustic_delay
        ];
        let text_tokens = vec![UNGENERATED; max_step_idx + config.acoustic_delay];
        let forced_audio_tokens = crate::lm::ForcedAudioTokens::new(
            config.acoustic_delay,
            config.audio_pad_token(),
            // Two streams of 8 codebooks each.
            &[8, 8],
        );
        Self {
            model,
            audio_tokens,
            text_tokens,
            audio_lp,
            text_lp,
            step_idx: 0,
            pad_mult,
            repetition_penalty,
            forced_audio_tokens,
            user_rating: 0, // 0 indicates that no rating has been submitted from the frontend.
            cfg_alpha,
            config,
        }
    }

    pub fn step_idx(&self) -> usize {
        self.step_idx
    }

    fn audio_pad_token(&self) -> u32 {
        self.config.audio_pad_token()
    }

    pub fn config(&self) -> &Config {
        &self.config
    }

    pub fn user_rating(&self) -> u32 {
        self.user_rating
    }

    pub fn set_user_rating(&mut self, grade: u32) {
        self.user_rating = grade;
    }

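    /// Penalizes tokens that already appeared among the last `context_size`
    /// non-special text tokens: positive logits are divided by the penalty and
    /// negative ones multiplied by it. A minimal sketch of the rule on plain
    /// floats:
    ///
    /// ```
    /// let penalty = 1.5f32;
    /// let apply = |logit: f32| if logit >= 0. { logit / penalty } else { logit * penalty };
    /// assert_eq!(apply(3.0), 2.0); // likely tokens become less likely
    /// assert_eq!(apply(-1.0), -1.5); // unlikely tokens become even less likely
    /// ```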
    fn apply_repetition_penalty(&self, logits: Tensor) -> candle::Result<Tensor> {
        let logits = match self.repetition_penalty {
            None => logits,
            Some((_, 1.)) => logits,
            Some((context_size, penalty)) => {
                let device = logits.device();
                let mut logits = logits.to_dtype(candle::DType::F32)?.to_vec1::<f32>()?;
                let mut already_seen = std::collections::HashSet::new();
                let mut non_pad_tokens = 0;
                for &token_id in self.text_tokens(false).iter().rev() {
                    if token_id == self.config.text_pad_token
                        || token_id == self.config.text_eop_token
                        || token_id == self.config.text_start_token
                    {
                        continue;
                    }
                    // Look at the last `context_size` tokens at most, counting every
                    // non-special token even if we have already seen it.
                    if non_pad_tokens >= context_size {
                        break;
                    }
                    non_pad_tokens += 1;

                    if already_seen.contains(&token_id) {
                        continue;
                    }

                    already_seen.insert(token_id);
                    if let Some(logit) = logits.get_mut(token_id as usize) {
                        if *logit >= 0. {
                            *logit /= penalty
                        } else {
                            *logit *= penalty
                        }
                    }
                }
                let logits_len = logits.len();
                Tensor::from_vec(logits, logits_len, device)?
            }
        };
        Ok(logits)
    }

    /// The acoustic tokens are written with a delay, so this can create "gaps" of
    /// UNGENERATED tokens when `step_audio_prompt` is called *after* `step`.
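    ///
    /// A minimal sketch of the read offsets used below, assuming
    /// `acoustic_delay = 2`: semantic codebooks are read from the previous step,
    /// acoustic codebooks from `acoustic_delay + 1` steps back.
    ///
    /// ```
    /// let (acoustic_delay, step) = (2usize, 5usize);
    /// let semantic_src = step - 1;
    /// let acoustic_src = step - acoustic_delay - 1;
    /// assert_eq!((semantic_src, acoustic_src), (4, 2));
    /// ```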
    pub fn step_(
        &mut self,
        text_token: Option<u32>,
        input_audio_tokens: &[u32],
        force_text_token: Option<u32>,
        ca_src: Option<&CaSrc>,
        conditions: Option<&crate::conditioner::Condition>,
    ) -> candle::Result<u32> {
        let mut codes = Vec::with_capacity(self.config.total_audio_codebooks());
        let dev = self.model.device();
        for (c_idx, &t) in input_audio_tokens.iter().enumerate() {
            self.audio_tokens[self.step_idx][c_idx + self.config.generated_audio_codebooks] = t
        }
        // With classifier-free guidance we run a batch of two: a conditional row and
        // an unconditional one.
        let batch_size = if self.cfg_alpha.is_some() { 2 } else { 1 };
        for codebook in 0..self.config.total_audio_codebooks() {
            // The first codebook of each stream is semantic and read without delay;
            // acoustic codebooks are read `acoustic_delay` steps further back.
            let t = if codebook == 0 || codebook == self.config.generated_audio_codebooks {
                if self.step_idx == 0 {
                    self.audio_pad_token()
                } else {
                    self.audio_tokens[self.step_idx - 1][codebook]
                }
            } else if self.step_idx <= self.config.acoustic_delay {
                self.audio_pad_token()
            } else {
                self.audio_tokens[self.step_idx - self.config.acoustic_delay - 1][codebook]
            };
            if t == UNGENERATED {
                candle::bail!("internal error: ungenerated token at step {} codebook {codebook}", self.step_idx)
            }
            let t = Tensor::from_vec(vec![t; batch_size], (batch_size, 1), dev)?;
            codes.push(Some(t))
        }
        let text_token = match text_token {
            Some(text_token) => {
                Some(Tensor::from_vec(vec![text_token; batch_size], (batch_size, 1), dev)?)
            }
            None => None,
        };
        let (text_logits, ys) = match ca_src {
            None => {
                let (logits, ys) =
                    self.model.forward_cond(text_token, codes, conditions, &().into())?;
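                // Classifier-free guidance: row 0 holds the conditional logits and row 1
                // the unconditional ones; blend them as alpha * cond - (alpha - 1) * uncond.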
                let logits = match self.cfg_alpha {
                    None => logits.i((0, 0))?,
                    Some(a) => match logits.dim(0)? {
                        2 => ((logits.i((0, 0))? * a)? - (logits.i((1, 0))? * (a - 1.))?)?,
                        b_size => candle::bail!("unexpected batch size {b_size}"),
                    },
                };
                (logits, ys)
            }
            Some(ca_src) => {
                if self.cfg_alpha.is_some() {
                    candle::bail!("cfg is not supported with cross attention")
                }
                let (logits, ys) =
                    self.model.forward_ca(text_token, codes, ca_src, None, &().into())?;
                (logits.i((0, 0))?, ys)
            }
        };
        let text_logits = self.apply_repetition_penalty(text_logits)?;
        let text_token = match force_text_token {
            Some(tt) => tt,
            None => self.text_lp.sample_f(&text_logits, |prs| {
                // Scale the probability of the text pad token by exp(pad_mult).
                if let Some(pad_mult) = self.pad_mult.as_ref() {
                    prs[self.config.text_pad_token as usize] *= f32::exp(*pad_mult);
                }
            })?,
        };
        self.text_tokens[self.step_idx] = text_token;
        let last_audio_tokens = match self.cfg_alpha {
            None => self.model.depformer_sample(
                &ys,
                Some(text_token),
                self.forced_audio_tokens.forced_tokens(self.step_idx),
                &mut self.audio_lp,
            )?,
            Some(cfg_alpha) => self.model.depformer_sample_cfg(
                &ys,
                cfg_alpha,
                Some(text_token),
                self.forced_audio_tokens.forced_tokens(self.step_idx),
                &mut self.audio_lp,
            )?,
        };
        let audio_pad_token = self.audio_pad_token();
        for c_idx in 0..self.config.generated_audio_codebooks {
            let delay = if c_idx == 0 || c_idx == self.config.generated_audio_codebooks {
                0
            } else {
                self.config.acoustic_delay
            };
            let pos = &mut self.audio_tokens[self.step_idx.saturating_sub(delay)][c_idx];
            // Overwrite existing positions even if they hold non-UNGENERATED values; this
            // actually happens for the first few slices because of the saturating_sub.
            *pos = last_audio_tokens.as_ref().map_or(audio_pad_token, |l| l[c_idx]);
        }
        self.step_idx += 1;
        if self.step_idx >= self.audio_tokens.len() {
            candle::bail!("max step-idx reached")
        }
        Ok(text_token)
    }

    pub fn step_without_ca_src(
        &mut self,
        text_token: u32,
        input_audio_tokens: &[u32],
        force_text_token: Option<u32>,
    ) -> candle::Result<u32> {
        self.step_(Some(text_token), input_audio_tokens, force_text_token, None, None)
    }

    pub fn step(
        &mut self,
        text_token: u32,
        input_audio_tokens: &[u32],
        force_text_token: Option<u32>,
        ca_src: Option<&CaSrc>,
    ) -> candle::Result<u32> {
        self.step_(Some(text_token), input_audio_tokens, force_text_token, ca_src, None)
    }

    /// If `include_all` is set, all the time steps are returned. Otherwise only
    /// the time steps that have already been generated are returned.
    pub fn audio_tokens(&self, include_all: bool) -> &[Vec<u32>] {
        if include_all {
            &self.audio_tokens
        } else {
            let max_idx = usize::min(self.step_idx, self.audio_tokens.len());
            &self.audio_tokens[..max_idx]
        }
    }

    pub fn text_tokens(&self, include_all: bool) -> &[u32] {
        if include_all {
            &self.text_tokens
        } else {
            let max_idx = usize::min(self.step_idx, self.text_tokens.len());
            &self.text_tokens[..max_idx]
        }
    }

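    /// Returns the audio tokens of the most recent fully generated time step, or
    /// `None` while the acoustic delay has not elapsed or when the slice still
    /// contains pad or UNGENERATED values. A minimal sketch of the validity
    /// check, assuming `audio_vocab_size = 2049`:
    ///
    /// ```
    /// let audio_vocab_size = 2049usize;
    /// let is_valid = |v: u32| (v as usize) < audio_vocab_size - 1;
    /// assert!(is_valid(2047)); // a regular audio token
    /// assert!(!is_valid(2048)); // the audio pad token
    /// assert!(!is_valid(u32::MAX)); // the UNGENERATED sentinel
    /// ```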
    pub fn last_audio_tokens(&self) -> Option<Vec<u32>> {
        if self.step_idx <= self.config.acoustic_delay {
            None
        } else {
            // step_idx is one step ahead, and audio tokens are written with an
            // `acoustic_delay` offset.
            let audio_tokens = &self.audio_tokens[self.step_idx - self.config.acoustic_delay - 1];
            if audio_tokens.iter().any(|v| *v as usize >= self.config.audio_vocab_size - 1) {
                None
            } else {
                Some(audio_tokens.clone())
            }
        }
    }
}