use crate::lm::LmModel;
use crate::mimi::Mimi;
use candle::{IndexOp, Result, Tensor};

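/// Messages emitted while decoding: per-step probabilities coming from the extra
/// heads, decoded words together with their start time, and end-of-word markers.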
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum AsrMsg {
    Step { step_idx: usize, prs: Vec<Vec<f32>> },
    Word { tokens: Vec<u32>, start_time: f64, batch_idx: usize },
    EndWord { stop_time: f64, batch_idx: usize },
}

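/// Per-stream decoding state for one entry of the batch.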
#[derive(Debug, Clone)]
pub struct ItemState {
    step_idx: usize,
    text_token: u32,
    word_tokens: Vec<u32>,
    unended_word: bool,
    last_stop_time: f64,
    audio_pad_token: u32,
    next_codebooks: Vec<u32>,
}

impl ItemState {
    fn reset(&mut self) {
        self.step_idx = 0;
        self.text_token = 0;
        self.word_tokens.clear();
        self.unended_word = false;
        self.last_stop_time = 0.;
        self.next_codebooks.fill(self.audio_pad_token);
    }

    pub fn text_token(&self) -> u32 {
        self.text_token
    }

    pub fn is_first_step(&self) -> bool {
        self.step_idx == 0
    }

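    /// Stores the freshly received token for `codebook_idx` and returns the one
    /// received at the previous step, so that audio codes are fed to the LM with a
    /// one-step delay. On the very first step the pad token is returned instead.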
    pub fn next_token(&mut self, codebook_idx: usize, token: u32) -> u32 {
        let v = self.next_codebooks[codebook_idx];
        self.next_codebooks[codebook_idx] = token;
        if self.is_first_step() {
            self.audio_pad_token
        } else {
            v
        }
    }
}

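/// Streaming ASR state: the language model, the Mimi audio tokenizer, and the
/// per-stream decoding state for every entry of the batch.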
pub struct State {
    asr_delay_in_tokens: usize,
    model_step_idx: usize,
    temperature: f64,
    lm: LmModel,
    audio_tokenizer: Mimi,
    device: candle::Device,
    batch: Vec<ItemState>,
}

impl State {
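    /// Builds a batched ASR state. `asr_delay_in_tokens` is the number of steps the
    /// text stream lags behind the audio, and `temperature` selects between greedy
    /// decoding (<= 0.0) and sampling.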
    pub fn new(
        batch_size: usize,
        asr_delay_in_tokens: usize,
        temperature: f64,
        audio_tokenizer: Mimi,
        lm: LmModel,
    ) -> Result<Self> {
        let text_token = lm.text_start_token();
        let device = lm.device().clone();
        let item_state = ItemState {
            text_token,
            word_tokens: vec![],
            unended_word: false,
            step_idx: 0,
            last_stop_time: 0.,
            audio_pad_token: lm.audio_pad_token(),
            next_codebooks: vec![lm.audio_pad_token(); lm.in_audio_codebooks()],
        };
        let mut s = Self {
            asr_delay_in_tokens,
            lm,
            model_step_idx: 0,
            audio_tokenizer,
            temperature,
            device,
            batch: vec![item_state; batch_size],
        };
        s.reset()?;
        Ok(s)
    }

    pub fn model_step_idx(&self) -> usize {
        self.model_step_idx
    }

    pub fn device(&self) -> &candle::Device {
        &self.device
    }

    pub fn batch_size(&self) -> usize {
        self.batch.len()
    }

    pub fn asr_delay_in_tokens(&self) -> usize {
        self.asr_delay_in_tokens
    }

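    /// Resets the language model, the audio tokenizer, and every per-stream state.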
    pub fn reset(&mut self) -> Result<()> {
        self.lm.reset_state();
        self.audio_tokenizer.reset_state();
        self.batch.iter_mut().for_each(|s| s.reset());
        Ok(())
    }

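    /// Encodes a chunk of PCM with the Mimi tokenizer and, if a frame of audio
    /// tokens is available, runs the language model on it. The callback `f` is
    /// invoked with the per-stream states and the tokens fed to the model.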
    pub fn step_pcm<F>(
        &mut self,
        pcm: Tensor,
        conditions: Option<&crate::conditioner::Condition>,
        mask: &crate::StreamMask,
        f: F,
    ) -> Result<Vec<AsrMsg>>
    where
        F: Fn(&[ItemState], &Tensor, &[Tensor]),
    {
        let audio_tokens = self.audio_tokenizer.encode_step(&pcm.into(), mask)?;
        if let Some(audio_tokens) = audio_tokens.as_option() {
            self.step_tokens(audio_tokens, conditions, mask, f)
        } else {
            Ok(vec![])
        }
    }

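    /// Builds the text-token tensor fed to the LM: the start token for streams on
    /// their first step, the last sampled token otherwise.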
    fn text_tokens(&self) -> Result<Tensor> {
        let batch_size = self.batch_size();
        let text_start_token = self.lm.text_start_token();
        let dev = self.lm.device();
        let text_tokens = self
            .batch
            .iter()
            .map(|s| if s.is_first_step() { text_start_token } else { s.text_token() })
            .collect::<Vec<_>>();
        Tensor::from_vec(text_tokens, (batch_size, 1), dev)
    }

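    /// Runs one LM step per time-step contained in `audio_tokens` (shape
    /// `(batch_size, codebooks, steps)`), samples the text tokens, and turns them
    /// into `AsrMsg` word and timing events.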
    pub fn step_tokens<F>(
        &mut self,
        audio_tokens: &Tensor,
        conditions: Option<&crate::conditioner::Condition>,
        mask: &crate::StreamMask,
        f: F,
    ) -> Result<Vec<AsrMsg>>
    where
        F: Fn(&[ItemState], &Tensor, &[Tensor]),
    {
        let (batch_size, codebooks, steps) = audio_tokens.dims3()?;
        if batch_size != self.batch_size() {
            candle::bail!("batch size mismatch: {batch_size} != {}", self.batch_size());
        }
        let mut words = vec![];
        for step in 0..steps {
            let audio_tokens = audio_tokens.narrow(2, step, 1)?;
            let audio_tokens = audio_tokens.reshape((batch_size, codebooks))?.to_vec2::<u32>()?;
            // For each codebook, feed the token received at the previous step
            // (see `ItemState::next_token`), using 0 for inactive streams.
            let audio_tokens = (0..codebooks)
                .map(|codebook_idx| {
                    let audio_tokens = audio_tokens
                        .iter()
                        .zip(self.batch.iter_mut())
                        .enumerate()
                        .map(|(batch_idx, (audio_token, item))| {
                            if !mask.is_active(batch_idx) {
                                0
                            } else {
                                item.next_token(codebook_idx, audio_token[codebook_idx])
                            }
                        })
                        .collect();
                    let audio_tokens =
                        Tensor::from_vec(audio_tokens, (batch_size, 1), self.device())?;
                    Ok(audio_tokens)
                })
                .collect::<Result<Vec<_>>>()?;
            let text = self.text_tokens()?;
            f(self.batch.as_slice(), &text, &audio_tokens);
            let audio_tokens = audio_tokens.into_iter().map(Some).collect::<Vec<_>>();
            let (text_logits, transformer_out) =
                self.lm.forward_cond(Some(text), audio_tokens, conditions, mask)?;
            self.model_step_idx += 1;
            // For each extra head, the probability it assigns to index 0, per stream.
            let extra_heads = self.lm.extra_heads(&transformer_out)?;
            let mut prs = vec![];
            for extra_head in extra_heads.iter() {
                let prs_ =
                    candle_nn::ops::softmax_last_dim(&extra_head.to_dtype(candle::DType::F32)?)?
                        .i((.., 0, 0))?
                        .to_vec1::<f32>()?;
                prs.push(prs_);
            }
            if !prs.is_empty() {
                words.push(AsrMsg::Step { step_idx: self.model_step_idx(), prs });
            }

            // Greedy decoding when the temperature is <= 0, Gumbel-softmax sampling otherwise.
            let text_tokens = if self.temperature <= 0.0 {
                text_logits.i((.., 0))?.argmax(candle::D::Minus1)?
            } else {
                candle_nn::sampling::gumbel_softmax(
                    &text_logits.i((.., 0))?.to_dtype(candle::DType::F32)?,
                    self.temperature,
                    candle::D::Minus1,
                )?
            };
            let text_tokens = text_tokens.to_vec1::<u32>()?;
            for (batch_idx, (text_token, item)) in
                text_tokens.into_iter().zip(self.batch.iter_mut()).enumerate()
            {
                if !mask.is_active(batch_idx) {
                    continue;
                }
                item.text_token = text_token;
                item.step_idx += 1;
                if item.step_idx >= self.asr_delay_in_tokens {
                    // Tokens 0 and 3 act as word boundaries: flush the pending word tokens.
                    if text_token == 3 || text_token == 0 {
                        if !item.word_tokens.is_empty() {
                            let mut tokens = vec![];
                            std::mem::swap(&mut item.word_tokens, &mut tokens);
                            words.push(AsrMsg::Word {
                                tokens,
                                start_time: item.last_stop_time,
                                batch_idx,
                            });
                            item.unended_word = true;
                        }
                    } else {
                        item.word_tokens.push(item.text_token);
                    }
                    // Token 0 also closes the current word; timestamps are derived
                    // from the step index at 12.5 steps per second.
                    if item.text_token == 0 {
                        let stop_time = (item.step_idx - self.asr_delay_in_tokens) as f64 / 12.5;
                        if item.unended_word {
                            item.unended_word = false;
                            words.push(AsrMsg::EndWord { stop_time, batch_idx });
                        }
                        item.last_stop_time = stop_time;
                    }
                }
            }
        }
        Ok(words)
    }

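    /// Resets a single stream of the batch, both in the per-stream state and in the
    /// language model and audio tokenizer caches.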
    pub fn reset_batch_idx(&mut self, batch_idx: usize) -> Result<()> {
        if batch_idx >= self.batch_size() {
            candle::bail!("batch index out of range: {batch_idx} >= {}", self.batch_size());
        }
        self.batch[batch_idx].reset();
        self.lm.reset_batch_idx(batch_idx, self.batch_size())?;
        self.audio_tokenizer.reset_batch_idx(batch_idx, self.batch_size())?;
        Ok(())
    }
}