moshi_db/
lm_generate.rs

1// Copyright (c) Kyutai, all rights reserved.
2// This source code is licensed under the license found in the
3// LICENSE file in the root directory of this source tree.
4
5// The state struct in this module handles generation for a LM model:
6// - Apply the audio delays.
7// - Allow for teacher forcing of the audio/text tokens.
8// - Support "literal-zeros" tokens for both text and audio.
9// - Make no assumptions on the number of streams.
10// - TODO: Handle batch size > 1
11// - TODO: Support CFG.
12// - TODO: Use CPU based tensors for storing the tokens?
13
14use candle::{IndexOp, Result, Tensor};
15use candle_transformers::generation::LogitsProcessor;
16
/// Content of one text or audio slot in the generation buffers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Token {
    /// A concrete token id, either sampled by the model or imposed by the caller.
    Set(u32),
    /// The slot has not been filled yet; this is the initial value of every slot.
    Ungenerated,
    /// "Literal-zeros" marker: when this slot is used as a model input, no token
    /// is passed for it (it is mapped to `None`).
    LiteralZero,
}
23
24#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
25pub struct Config {
26    pub audio_delays: Vec<usize>,
27    pub audio_vocab_size: usize,
28    pub text_pad_token: u32,
29    pub text_eop_token: u32,
30    pub text_start_token: u32,
31}
32
33impl Config {
34    pub fn audio_pad_token(&self) -> u32 {
35        self.audio_vocab_size as u32 - 1
36    }
37
38    pub fn audio_codebooks(&self) -> usize {
39        self.audio_delays.len()
40    }
41
42    pub fn max_audio_delay(&self) -> usize {
43        self.audio_delays.iter().max().cloned().unwrap_or(0)
44    }
45}
46
47pub struct State {
48    model: crate::lm::LmModel,
49    audio_tokens: Vec<Vec<Token>>,
50    text_tokens: Vec<Token>,
51    audio_lp: LogitsProcessor,
52    text_lp: LogitsProcessor,
53    step_idx: usize,
54    config: Config,
55}
56
57impl State {
58    pub fn new(
59        model: crate::lm::LmModel,
60        max_step_idx: usize,
61        audio_lp: LogitsProcessor,
62        text_lp: LogitsProcessor,
63        config: Config,
64    ) -> Self {
65        // TODO(laurent): handle a batch dimension.
66        let total_len = max_step_idx + config.max_audio_delay();
67        let audio_tokens = vec![vec![Token::Ungenerated; config.audio_codebooks()]; total_len];
68        let text_tokens = vec![Token::Ungenerated; total_len];
69        Self { model, audio_tokens, text_tokens, audio_lp, text_lp, step_idx: 0, config }
70    }
71
72    pub fn step_idx(&self) -> usize {
73        self.step_idx
74    }
75
76    pub fn audio_pad_token(&self) -> u32 {
77        self.config.audio_pad_token()
78    }
79
80    pub fn config(&self) -> &Config {
81        &self.config
82    }
83
84    pub fn set_audio_tokens(&mut self, audio_tokens: &[Option<Token>]) -> Result<()> {
85        for (s, at) in self.audio_tokens[self.step_idx].iter_mut().zip(audio_tokens.iter()) {
86            if let Some(at) = at {
87                *s = *at
88            }
89        }
90        Ok(())
91    }
92
93    pub fn step(&mut self, conditions: Option<&crate::conditioner::Condition>) -> Result<()> {
94        let dev = self.model.device();
95
96        let mut forced_audio_tokens = Vec::with_capacity(self.config.audio_codebooks());
97        for (codebook, &delay) in self.config.audio_delays.iter().enumerate() {
98            let forced_token = if self.step_idx < delay {
99                Some(self.audio_pad_token())
100            } else {
101                match self.audio_tokens[self.step_idx - delay][codebook] {
102                    Token::Ungenerated | Token::LiteralZero => None,
103                    Token::Set(v) => Some(v),
104                }
105            };
106            forced_audio_tokens.push(forced_token);
107        }
108
109        let mut codes = Vec::with_capacity(self.config.audio_codebooks());
110        for (codebook, &delay) in self.config.audio_delays.iter().enumerate() {
111            let t = if self.step_idx <= delay {
112                Some(self.audio_pad_token())
113            } else {
114                match self.audio_tokens[self.step_idx - delay - 1][codebook] {
115                    Token::LiteralZero => None,
116                    Token::Set(v) => Some(v),
117                    Token::Ungenerated => {
118                        candle::bail!("internal error, ungenerated {} {codebook}", self.step_idx)
119                    }
120                }
121            };
122            let t = match t {
123                None => None,
124                Some(t) => Some(Tensor::from_vec(vec![t; 1], (1, 1), dev)?),
125            };
126            codes.push(t)
127        }
128        let text_token = if self.step_idx == 0 {
129            Some(self.config.text_start_token)
130        } else {
131            match self.text_tokens[self.step_idx - 1] {
132                Token::LiteralZero => None,
133                Token::Set(t) => Some(t),
134                Token::Ungenerated => {
135                    candle::bail!("internal error, ungenerated {} text", self.step_idx)
136                }
137            }
138        };
139        let text_token = match text_token {
140            None => None,
141            Some(t) => Some(Tensor::from_vec(vec![t; 1], (1, 1), dev)?),
142        };
143        let (text_logits, ys) =
144            self.model.forward_cond(text_token, codes, conditions, &().into())?;
145        let text_token = match self.text_tokens[self.step_idx] {
146            Token::Ungenerated => {
147                let t = self.text_lp.sample(&text_logits.i((0, 0))?)?;
148                self.text_tokens[self.step_idx] = Token::Set(t);
149                Some(t)
150            }
151            Token::Set(t) => Some(t),
152            Token::LiteralZero => None,
153        };
154        let audio_tokens = self.model.depformer_sample(
155            &ys,
156            text_token,
157            &forced_audio_tokens,
158            &mut self.audio_lp,
159        )?;
160        if let Some(audio_tokens) = audio_tokens {
161            for (codebook, audio_token) in audio_tokens.into_iter().enumerate() {
162                let delay = self.config.audio_delays[codebook];
163                if self.step_idx < delay {
164                    continue;
165                }
166                let pos = &mut self.audio_tokens[self.step_idx - delay][codebook];
167                if *pos == Token::Ungenerated {
168                    *pos = Token::Set(audio_token)
169                }
170            }
171        }
172        self.step_idx += 1;
173        if self.step_idx >= self.audio_tokens.len() {
174            candle::bail!("max step-idx reached")
175        }
176        Ok(())
177    }
178
179    pub fn last_text_token(&self) -> Result<Option<u32>> {
180        if self.step_idx == 0 {
181            Ok(None)
182        } else {
183            match self.text_tokens[self.step_idx - 1] {
184                Token::Set(t) => Ok(Some(t)),
185                Token::LiteralZero => Ok(None),
186                Token::Ungenerated => {
187                    candle::bail!("internal error, ungenerated step {}, text", self.step_idx)
188                }
189            }
190        }
191    }
192
193    pub fn last_audio_tokens(&self) -> Result<Option<Vec<u32>>> {
194        let max_audio_delay = self.config.max_audio_delay();
195        if self.step_idx <= max_audio_delay {
196            Ok(None)
197        } else {
198            let mut audio_tokens = vec![];
199            for (cb, audio_token) in
200                self.audio_tokens[self.step_idx - max_audio_delay - 1].iter().enumerate()
201            {
202                match audio_token {
203                    Token::LiteralZero => return Ok(None),
204                    Token::Set(s) => audio_tokens.push(*s),
205                    Token::Ungenerated => {
206                        candle::bail!("internal error, ungenerated step {}, cb {cb}", self.step_idx)
207                    }
208                }
209            }
210            Ok(Some(audio_tokens))
211        }
212    }
213}