// moshi_db/asr.rs

1// Copyright (c) Kyutai, all rights reserved.
2// This source code is licensed under the license found in the
3// LICENSE file in the root directory of this source tree.
4use crate::lm::LmModel;
5use crate::mimi::Mimi;
6use candle::{IndexOp, Result, Tensor};
7
/// Messages emitted by the streaming ASR decoder.
///
/// Serialized with serde, so variant and field names are part of the wire
/// format — do not rename them.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum AsrMsg {
    /// Per-model-step probabilities from the LM's extra heads; one inner
    /// `Vec<f32>` per head, one entry per batch element.
    Step { step_idx: usize, prs: Vec<Vec<f32>> },
    /// A completed word: its text tokens and the time (seconds) at which it
    /// started, i.e. the stop time of the previous word on this stream.
    Word { tokens: Vec<u32>, start_time: f64, batch_idx: usize },
    /// Closes the most recently emitted `Word` with its stop time (seconds).
    EndWord { stop_time: f64, batch_idx: usize },
}
14
/// Per-stream decoding state for one element of the batch.
#[derive(Debug, Clone)]
pub struct ItemState {
    // Steps consumed on this stream since the last reset.
    step_idx: usize,
    // Last sampled text token for this stream (0 right after a reset).
    text_token: u32,
    // Text tokens accumulated for the word currently being decoded.
    word_tokens: Vec<u32>,
    // True when a `Word` message was emitted whose `EndWord` is still pending.
    unended_word: bool,
    // Stop time (seconds) of the last finished word; used as the start time
    // of the next one.
    last_stop_time: f64,
    // Pad token fed to the model on the very first audio step.
    audio_pad_token: u32,
    // One pending audio token per input codebook: tokens are delayed by one
    // step through this buffer (see `next_token`).
    next_codebooks: Vec<u32>,
}
25
26impl ItemState {
27    fn reset(&mut self) {
28        self.step_idx = 0;
29        self.text_token = 0;
30        self.word_tokens.clear();
31        self.unended_word = false;
32        self.last_stop_time = 0.;
33        self.next_codebooks.fill(self.audio_pad_token);
34    }
35
36    pub fn text_token(&self) -> u32 {
37        self.text_token
38    }
39
40    pub fn is_first_step(&self) -> bool {
41        self.step_idx == 0
42    }
43
44    pub fn next_token(&mut self, codebook_idx: usize, token: u32) -> u32 {
45        let v = self.next_codebooks[codebook_idx];
46        self.next_codebooks[codebook_idx] = token;
47        if self.is_first_step() {
48            self.audio_pad_token
49        } else {
50            v
51        }
52    }
53}
54
/// Batched streaming ASR state: a Mimi audio tokenizer feeding an LM,
/// plus per-stream word-assembly bookkeeping.
pub struct State {
    // Number of steps the text stream lags behind the audio; word timing is
    // computed relative to this offset.
    asr_delay_in_tokens: usize,
    // Total LM forward steps taken; monotonically increasing, never reset
    // per stream.
    model_step_idx: usize,
    // Text sampling temperature; values <= 0 select greedy argmax.
    temperature: f64,
    lm: LmModel,
    audio_tokenizer: Mimi,
    // Cached copy of the LM's device, used when building input tensors.
    device: candle::Device,
    // Per-stream decoding state, one entry per batch element.
    batch: Vec<ItemState>,
}
64
impl State {
    /// Build a fresh batched ASR state from a Mimi tokenizer and an LM.
    ///
    /// Every batch slot starts from the same `ItemState`; `reset` is called
    /// before returning, which zeroes the per-slot text token and re-primes
    /// the pending codebook buffers with the audio pad token.
    pub fn new(
        batch_size: usize,
        asr_delay_in_tokens: usize,
        temperature: f64,
        audio_tokenizer: Mimi,
        lm: LmModel,
    ) -> Result<Self> {
        let text_token = lm.text_start_token();
        let device = lm.device().clone();
        let item_state = ItemState {
            text_token,
            word_tokens: vec![],
            unended_word: false,
            step_idx: 0,
            last_stop_time: 0.,
            audio_pad_token: lm.audio_pad_token(),
            next_codebooks: vec![lm.audio_pad_token(); lm.in_audio_codebooks()],
        };
        let mut s = Self {
            asr_delay_in_tokens,
            lm,
            model_step_idx: 0,
            audio_tokenizer,
            temperature,
            device,
            batch: vec![item_state; batch_size],
        };
        s.reset()?;
        Ok(s)
    }

    /// Total number of LM forward steps taken so far.
    pub fn model_step_idx(&self) -> usize {
        self.model_step_idx
    }

    /// Device on which input tensors are created (the LM's device).
    pub fn device(&self) -> &candle::Device {
        &self.device
    }

    /// Number of parallel streams handled by this state.
    pub fn batch_size(&self) -> usize {
        self.batch.len()
    }

    /// Delay, in steps, between the audio stream and the text stream.
    pub fn asr_delay_in_tokens(&self) -> usize {
        self.asr_delay_in_tokens
    }

    /// Reset the LM, the audio tokenizer, and every per-stream slot.
    /// Note: `model_step_idx` is deliberately left untouched.
    pub fn reset(&mut self) -> Result<()> {
        self.lm.reset_state();
        self.audio_tokenizer.reset_state();
        self.batch.iter_mut().for_each(|s| s.reset());
        Ok(())
    }

    /// Encode a chunk of raw PCM into audio tokens and run the decoding
    /// loop on them. Returns no messages while the tokenizer is still
    /// buffering (i.e. before it has produced a full frame).
    ///
    /// `f` is invoked once per step with the batch state and the exact
    /// text/audio tensors fed to the LM — see `step_tokens`.
    pub fn step_pcm<F>(
        &mut self,
        pcm: Tensor,
        conditions: Option<&crate::conditioner::Condition>,
        mask: &crate::StreamMask,
        f: F,
    ) -> Result<Vec<AsrMsg>>
    where
        F: Fn(&[ItemState], &Tensor, &[Tensor]),
    {
        let audio_tokens = self.audio_tokenizer.encode_step(&pcm.into(), mask)?;
        if let Some(audio_tokens) = audio_tokens.as_option() {
            self.step_tokens(audio_tokens, conditions, mask, f)
        } else {
            Ok(vec![])
        }
    }

    /// Build the `(batch_size, 1)` tensor of text tokens to feed the LM:
    /// the start token for streams on their first step, otherwise each
    /// stream's last sampled token.
    fn text_tokens(&self) -> Result<Tensor> {
        let batch_size = self.batch_size();
        let text_start_token = self.lm.text_start_token();
        // We used to have literal 0s for the first asr_delay_in_tokens - 1 steps
        // This is not the case anymore.
        let dev = self.lm.device();
        let text_tokens = self
            .batch
            .iter()
            .map(|s| if s.is_first_step() { text_start_token } else { s.text_token() })
            .collect::<Vec<_>>();
        Tensor::from_vec(text_tokens, (batch_size, 1), dev)
    }

    /// Run the LM over a `(batch_size, codebooks, steps)` tensor of audio
    /// tokens, one forward pass per step, and collect the resulting
    /// `AsrMsg`s (extra-head probabilities, words, word end times).
    ///
    /// `f` is called before each forward pass with the batch state, the
    /// text tensor, and the per-codebook audio tensors — e.g. for logging
    /// or teacher forcing inspection.
    pub fn step_tokens<F>(
        &mut self,
        audio_tokens: &Tensor,
        conditions: Option<&crate::conditioner::Condition>,
        mask: &crate::StreamMask,
        f: F,
    ) -> Result<Vec<AsrMsg>>
    where
        F: Fn(&[ItemState], &Tensor, &[Tensor]),
    {
        let (batch_size, codebooks, steps) = audio_tokens.dims3()?;
        if batch_size != self.batch_size() {
            candle::bail!("batch size mismatch: {batch_size} != {}", self.batch_size());
        }
        let mut words = vec![];
        for step in 0..steps {
            // Slice out this step: (batch, codebooks, 1) -> Vec<Vec<u32>>.
            let audio_tokens = audio_tokens.narrow(2, step, 1)?;
            let audio_tokens = audio_tokens.reshape((batch_size, codebooks))?.to_vec2::<u32>()?;
            // Re-batch per codebook, routing each token through the one-step
            // delay buffer in `ItemState::next_token`; inactive streams get 0.
            let audio_tokens = (0..codebooks)
                .map(|codebook_idx| {
                    let audio_tokens = audio_tokens
                        .iter()
                        .zip(self.batch.iter_mut())
                        .enumerate()
                        .map(|(batch_idx, (audio_token, item))| {
                            if !mask.is_active(batch_idx) {
                                0
                            } else {
                                item.next_token(codebook_idx, audio_token[codebook_idx])
                            }
                        })
                        .collect();
                    let audio_tokens =
                        Tensor::from_vec(audio_tokens, (batch_size, 1), self.device())?;
                    Ok(audio_tokens)
                })
                .collect::<Result<Vec<_>>>()?;
            let text = self.text_tokens()?;
            // Give the caller a look at the exact model inputs for this step.
            f(self.batch.as_slice(), &text, &audio_tokens);
            let audio_tokens = audio_tokens.into_iter().map(Some).collect::<Vec<_>>();
            let (text_logits, transformer_out) =
                self.lm.forward_cond(Some(text), audio_tokens, conditions, mask)?;
            self.model_step_idx += 1;
            let extra_heads = self.lm.extra_heads(&transformer_out)?;
            let mut prs = vec![];
            for extra_head in extra_heads.iter() {
                // Only retrieve the first element for each extra-head.
                let prs_ =
                    candle_nn::ops::softmax_last_dim(&extra_head.to_dtype(candle::DType::F32)?)?
                        .i((.., 0, 0))?
                        .to_vec1::<f32>()?;
                prs.push(prs_);
            }
            if !prs.is_empty() {
                words.push(AsrMsg::Step { step_idx: self.model_step_idx(), prs });
            }

            // Sample the next text token per stream: greedy argmax when the
            // temperature is non-positive, Gumbel softmax sampling otherwise.
            let text_tokens = if self.temperature <= 0.0 {
                text_logits.i((.., 0))?.argmax(candle::D::Minus1)?
            } else {
                candle_nn::sampling::gumbel_softmax(
                    &text_logits.i((.., 0))?.to_dtype(candle::DType::F32)?,
                    self.temperature,
                    candle::D::Minus1,
                )?
            };
            let text_tokens = text_tokens.to_vec1::<u32>()?;
            for (batch_idx, (text_token, item)) in
                text_tokens.into_iter().zip(self.batch.iter_mut()).enumerate()
            {
                if !mask.is_active(batch_idx) {
                    continue;
                }
                item.text_token = text_token;
                item.step_idx += 1;
                // Words are only emitted once the stream has caught up with
                // the configured audio/text delay.
                if item.step_idx >= self.asr_delay_in_tokens {
                    // Tokens 0 and 3 act as word separators here: flush the
                    // accumulated word, if any. NOTE(review): presumably 0 is
                    // the pad/EOS token and 3 a word-boundary marker — confirm
                    // against the text tokenizer's vocabulary.
                    if text_token == 3 || text_token == 0 {
                        if !item.word_tokens.is_empty() {
                            let mut tokens = vec![];
                            std::mem::swap(&mut item.word_tokens, &mut tokens);
                            words.push(AsrMsg::Word {
                                tokens,
                                start_time: item.last_stop_time,
                                batch_idx,
                            });
                            item.unended_word = true;
                        }
                    } else {
                        item.word_tokens.push(item.text_token)
                    }
                    // Token 0 additionally timestamps the end of the pending
                    // word. The 12.5 divisor converts steps to seconds —
                    // presumably Mimi's 12.5 Hz frame rate; TODO confirm.
                    if item.text_token == 0 {
                        let stop_time = (item.step_idx - self.asr_delay_in_tokens) as f64 / 12.5;
                        if item.unended_word {
                            item.unended_word = false;
                            words.push(AsrMsg::EndWord { stop_time, batch_idx });
                        }
                        item.last_stop_time = stop_time;
                    }
                }
            }
        }
        Ok(words)
    }

    /// Reset a single stream (its slot state plus the LM's and tokenizer's
    /// per-index state) without disturbing the other streams.
    ///
    /// # Errors
    /// Fails when `batch_idx` is out of range, or when the underlying
    /// resets fail.
    pub fn reset_batch_idx(&mut self, batch_idx: usize) -> Result<()> {
        if batch_idx >= self.batch_size() {
            candle::bail!("batch index out of range: {batch_idx} >= {}", self.batch_size());
        }
        self.batch[batch_idx].reset();
        self.lm.reset_batch_idx(batch_idx, self.batch_size())?;
        self.audio_tokenizer.reset_batch_idx(batch_idx, self.batch_size())?;
        Ok(())
    }
}