1use candle::{IndexOp, Result, Tensor};
6use candle_transformers::generation::LogitsProcessor;
7
8use crate::transformer::CaSrc;
9
/// Sentinel value marking a token slot that has not been generated yet.
pub const UNGENERATED: u32 = u32::MAX;
11
/// Static parameters driving the TTS generation state machine.
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
pub struct Config {
    // Number of steps the non-first audio codebooks lag behind the first one
    // (see the delayed indexing in `State::step`).
    pub acoustic_delay: usize,
    // Special text-token ids.
    pub text_pad_token: u32,
    pub text_bos_token: u32,
    pub text_eos_token: u32,
    pub text_eop_token: u32,
    pub text_start_token: u32,
    // Number of steps the audio stream lags behind the text stream.
    pub text_audio_delay_in_tokens: usize,
    // Once this many pads were emitted in a row, `State::step` forces an
    // end-of-pad token.
    pub max_consecutive_pads: usize,
    pub extra_steps: usize,
    // Speaker-conditioning settings (see `SpeakerEncoder`).
    pub speaker_cond_duration_s: f64,
    pub speaker_cond_dim: usize,
    pub speaker_cond_n_speakers: usize,
}
27
28impl Config {
29 pub fn v202501() -> Self {
30 Self {
31 acoustic_delay: 2,
32 text_eop_token: 0,
33 text_bos_token: 1,
34 text_eos_token: 2,
35 text_pad_token: 3,
36 text_start_token: 8000,
37 text_audio_delay_in_tokens: 25, max_consecutive_pads: 10,
39 extra_steps: 5,
40 speaker_cond_duration_s: 10.,
41 speaker_cond_dim: 2048,
42 speaker_cond_n_speakers: 5,
43 }
44 }
45}
46
/// Streaming generation state: the language model plus every token produced
/// so far and the sampling machinery.
pub struct State {
    model: crate::lm::LmModel,
    // Optional cross-attention source; selects `forward_ca` over
    // `forward_cond` in `step`.
    ca_src: Option<CaSrc>,
    // audio_tokens[step][codebook]; slots start out as UNGENERATED.
    audio_tokens: Vec<Vec<u32>>,
    text_tokens: Vec<u32>,
    // Number of text pad tokens sampled in a row; used to force an eop.
    consecutive_pads: usize,
    audio_lp: LogitsProcessor,
    text_lp: LogitsProcessor,
    step_idx: usize,
    forced_audio_tokens: crate::lm::ForcedAudioTokens,
    // Classifier-free-guidance weight; when set, forward passes use batch
    // size 2 and logits are combined as `alpha * cond - (alpha - 1) * uncond`.
    cfg_alpha: Option<f64>,
    config: Config,
}
60
/// Constrains which text token `State::step` may emit for a given step.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AllowedTokens {
    /// Force this exact text token.
    Text(u32),
    /// Force the pad token.
    Pad,
    /// Sample from the logits, then map the result to pad or end-of-pad.
    PadOrEpad,
}
67
impl State {
    /// Creates a new generation state.
    ///
    /// Token buffers are sized to `max_step_idx + config.acoustic_delay`
    /// steps and pre-filled with `UNGENERATED`.
    pub fn new(
        model: crate::lm::LmModel,
        ca_src: Option<CaSrc>,
        max_step_idx: usize,
        audio_lp: LogitsProcessor,
        text_lp: LogitsProcessor,
        cfg_alpha: Option<f64>,
        config: Config,
    ) -> Self {
        let audio_tokens: Vec<Vec<u32>> = vec![
            vec![UNGENERATED; model.generated_audio_codebooks()];
            max_step_idx + config.acoustic_delay
        ];
        let text_tokens = vec![UNGENERATED; max_step_idx + config.acoustic_delay];
        let forced_audio_tokens = crate::lm::ForcedAudioTokens::new(
            config.acoustic_delay,
            model.audio_pad_token(),
            &[model.generated_audio_codebooks()],
        );
        Self {
            model,
            ca_src,
            audio_tokens,
            text_tokens,
            consecutive_pads: 0,
            audio_lp,
            text_lp,
            step_idx: 0,
            forced_audio_tokens,
            cfg_alpha,
            config,
        }
    }

    /// Number of generation steps performed so far.
    pub fn step_idx(&self) -> usize {
        self.step_idx
    }

    fn audio_pad_token(&self) -> u32 {
        self.model.audio_pad_token()
    }

    pub fn config(&self) -> &Config {
        &self.config
    }

    /// Runs one generation step and returns the text token emitted for it.
    ///
    /// `prev_text_token` is the text token fed back as input,
    /// `allowed_tokens` constrains the text token produced this step, and
    /// `conditions` are forwarded to the underlying model. Audio tokens are
    /// sampled through the depformer and stored into the internal buffers,
    /// honoring both the audio/text delay and the acoustic delay.
    ///
    /// Errors if an `UNGENERATED` slot is fed back (internal invariant) or
    /// once the pre-allocated step budget is exhausted.
    pub fn step(
        &mut self,
        prev_text_token: u32,
        allowed_tokens: AllowedTokens,
        conditions: Option<&crate::conditioner::Condition>,
    ) -> Result<u32> {
        let mut codes = Vec::with_capacity(self.model.generated_audio_codebooks());
        let dev = self.model.device();
        // With classifier-free guidance every input is duplicated on a
        // batch of 2 (conditioned + unconditioned rows).
        let batch_size = if self.cfg_alpha.is_some() { 2 } else { 1 };
        // Build the per-codebook input token for this step. `Some(tok)` is a
        // real input, `None` means "no input yet" while within the delays.
        for codebook in 0..self.model.generated_audio_codebooks() {
            let t = if codebook == 0 {
                // First codebook: not subject to the acoustic delay.
                if self.step_idx == 0 {
                    Some(self.audio_pad_token())
                } else if self.step_idx <= self.config.text_audio_delay_in_tokens {
                    None
                } else {
                    Some(self.audio_tokens[self.step_idx - 1][codebook])
                }
            } else if self.step_idx <= self.config.acoustic_delay {
                // Remaining codebooks lag by `acoustic_delay` steps.
                Some(self.audio_pad_token())
            } else if self.step_idx
                <= self.config.text_audio_delay_in_tokens + self.config.acoustic_delay
            {
                None
            } else {
                Some(self.audio_tokens[self.step_idx - self.config.acoustic_delay - 1][codebook])
            };
            if t == Some(UNGENERATED) {
                candle::bail!("internal error, ungenerated {}", self.step_idx)
            }
            let t = match t {
                Some(t) => Some(Tensor::from_vec(vec![t; batch_size], (batch_size, 1), dev)?),
                None => None,
            };
            codes.push(t)
        }
        let prev_text_token =
            Some(Tensor::from_vec(vec![prev_text_token; batch_size], (batch_size, 1), dev)?);
        // Forward pass, with or without the cross-attention source.
        let (text_logits, ys) = match self.ca_src.as_ref() {
            None => self.model.forward_cond(prev_text_token, codes, conditions, &().into())?,
            Some(ca_src) => {
                self.model.forward_ca(prev_text_token, codes, ca_src, conditions, &().into())?
            }
        };
        // CFG combination: alpha * conditioned - (alpha - 1) * unconditioned.
        let text_logits = match self.cfg_alpha {
            None => text_logits.i((0, 0))?,
            Some(a) => match text_logits.dim(0)? {
                2 => ((text_logits.i((0, 0))? * a)? - (text_logits.i((1, 0))? * (a - 1.))?)?,
                b_size => candle::bail!("unexpected batch size {b_size}"),
            },
        };
        let text_token = match allowed_tokens {
            AllowedTokens::Text(v) => v,
            AllowedTokens::Pad => self.config.text_pad_token,
            AllowedTokens::PadOrEpad => {
                if self.consecutive_pads > self.config.max_consecutive_pads {
                    // Too many pads in a row: force end-of-pad.
                    self.config.text_eop_token
                } else {
                    // Sample, then collapse anything that is not a pad into
                    // an end-of-pad token.
                    let text_token = self.text_lp.sample(&text_logits)?;
                    if text_token == self.config.text_pad_token {
                        self.config.text_pad_token
                    } else {
                        self.config.text_eop_token
                    }
                }
            }
        };
        if text_token == self.config.text_pad_token {
            self.consecutive_pads += 1
        } else {
            self.consecutive_pads = 0
        }
        self.text_tokens[self.step_idx] = text_token;
        // Audio is only sampled once the text/audio delay has elapsed.
        let last_audio_tokens = if self.step_idx < self.config.text_audio_delay_in_tokens {
            None
        } else {
            match self.cfg_alpha {
                None => self.model.depformer_sample(
                    &ys,
                    Some(text_token),
                    self.forced_audio_tokens.forced_tokens(self.step_idx),
                    &mut self.audio_lp,
                )?,
                Some(cfg_alpha) => self.model.depformer_sample_cfg(
                    &ys,
                    cfg_alpha,
                    Some(text_token),
                    self.forced_audio_tokens.forced_tokens(self.step_idx),
                    &mut self.audio_lp,
                )?,
            }
        };
        // Write the sampled (or pad) tokens into the delayed slots, never
        // overwriting a slot that was already generated.
        let audio_pad_token = self.audio_pad_token();
        for c_idx in 0..self.model.generated_audio_codebooks() {
            let delay = if c_idx == 0 { 0 } else { self.config.acoustic_delay };
            let pos = &mut self.audio_tokens[self.step_idx.saturating_sub(delay)][c_idx];
            match last_audio_tokens.as_ref() {
                Some(lat) => {
                    if *pos == UNGENERATED {
                        *pos = lat[c_idx]
                    }
                }
                None => {
                    if *pos == UNGENERATED {
                        *pos = audio_pad_token
                    }
                }
            }
        }
        self.step_idx += 1;
        if self.step_idx >= self.audio_tokens.len() {
            candle::bail!("max step-idx reached")
        }
        Ok(text_token)
    }

    /// Replaces the text token recorded for the previous step.
    ///
    /// Errors when no step has run yet or when trying to write the
    /// `UNGENERATED` sentinel.
    pub fn overwrite_last_text_token(&mut self, text_token: u32) -> Result<()> {
        if self.step_idx == 0 {
            candle::bail!("cannot overwrite first token")
        }
        if text_token == UNGENERATED {
            candle::bail!("cannot overwrite with UNGENERATED")
        }
        self.text_tokens[self.step_idx - 1] = text_token;
        Ok(())
    }

    /// Audio tokens generated so far; with `include_all` the whole buffer
    /// (including not-yet-generated slots) is returned.
    pub fn audio_tokens(&self, include_all: bool) -> &[Vec<u32>] {
        if include_all {
            &self.audio_tokens
        } else {
            let max_idx = usize::min(self.step_idx, self.audio_tokens.len());
            &self.audio_tokens[..max_idx]
        }
    }

    /// Text tokens generated so far; with `include_all` the whole buffer
    /// (including not-yet-generated slots) is returned.
    pub fn text_tokens(&self, include_all: bool) -> &[u32] {
        if include_all {
            &self.text_tokens
        } else {
            let max_idx = usize::min(self.step_idx, self.text_tokens.len());
            &self.text_tokens[..max_idx]
        }
    }

    /// Most recent fully-generated audio step, or `None` while still within
    /// the acoustic delay or when any codebook holds a pad/out-of-range
    /// token.
    pub fn last_audio_tokens(&self) -> Option<Vec<u32>> {
        if self.step_idx <= self.config.acoustic_delay {
            None
        } else {
            let audio_tokens = &self.audio_tokens[self.step_idx - self.config.acoustic_delay - 1];
            // `>=` also filters UNGENERATED (u32::MAX), not just the pad.
            if audio_tokens.iter().any(|v| *v >= self.audio_pad_token()) {
                None
            } else {
                Some(audio_tokens.clone())
            }
        }
    }

    pub fn audio_codebooks(&self) -> usize {
        self.model.generated_audio_codebooks()
    }

    pub fn device(&self) -> &candle::Device {
        self.model.device()
    }

    pub fn dtype(&self) -> candle::DType {
        self.model.dtype()
    }
}
301
/// Which conversation side a prompt word belongs to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Speaker {
    Main,
    Other,
}

/// Tokenizes a multi-turn prompt into per-word token sequences, alternating
/// speakers per turn.
///
/// Even-indexed turns belong to `Speaker::Main` and get `text_bos_token`
/// prepended to their first word. Odd-indexed turns belong to
/// `Speaker::Other`; `text_eos_token` is selected for them but is currently
/// never inserted (only `Main` first words get a turn token). Words that
/// encode to an empty token list are skipped.
///
/// Returns the encoder error `E` unchanged if `encode` fails on any word.
pub fn tokenize_prompt<E>(
    text: &[String],
    text_bos_token: u32,
    text_eos_token: u32,
    encode: impl Fn(&str) -> std::result::Result<Vec<u32>, E>,
) -> std::result::Result<Vec<(Vec<u32>, Speaker)>, E> {
    let mut prompt = vec![];
    for (turn_idx, turn) in text.iter().enumerate() {
        let (speaker, turn_token) = if turn_idx % 2 == 0 {
            (Speaker::Main, text_bos_token)
        } else {
            (Speaker::Other, text_eos_token)
        };
        for (word_idx, word) in turn.split(' ').enumerate() {
            // `encode` already yields a Vec<u32>; the previous
            // `.into_iter().collect::<Vec<_>>()` round-trip was redundant.
            let mut word = encode(word)?;
            if word_idx == 0 && speaker == Speaker::Main {
                word.insert(0, turn_token)
            }
            if !word.is_empty() {
                prompt.push((word, speaker))
            }
        }
    }
    Ok(prompt)
}
333
/// Encodes reference speaker audio into conditioning embeddings via a mimi
/// encoder followed by a linear projection.
#[derive(Debug, Clone)]
pub struct SpeakerEncoder {
    mimi: crate::mimi::Mimi,
    // Learnt embedding of shape (1, 1, cond_dim) used to fill missing
    // speaker slots.
    learnt_padding: Tensor,
    proj: candle_nn::Linear,
    // Maximum number of speaker slots in the conditioning.
    n_speakers: usize,
    cond_dim: usize,
    device: candle::Device,
    dtype: candle::DType,
}
344
impl SpeakerEncoder {
    /// Builds the encoder, loading `learnt_padding` and the output
    /// projection from the `condition_provider.conditioners.speaker_wavs.*`
    /// weights in `vb`.
    pub fn new(
        mimi: crate::mimi::Mimi,
        speaker_cond_dim: usize,
        speaker_cond_n_speakers: usize,
        dtype: candle::DType,
        vb: candle_nn::VarBuilder,
    ) -> Result<Self> {
        let learnt_padding = vb.get(
            (1, 1, speaker_cond_dim),
            "condition_provider.conditioners.speaker_wavs.learnt_padding",
        )?;
        let mimi_dim = mimi.config().seanet.dimension;
        let proj = candle_nn::linear_no_bias(
            mimi_dim,
            speaker_cond_dim,
            vb.pp("condition_provider.conditioners.speaker_wavs.output_proj"),
        )?;
        Ok(Self {
            mimi,
            learnt_padding,
            proj,
            n_speakers: speaker_cond_n_speakers,
            cond_dim: speaker_cond_dim,
            device: vb.device().clone(),
            dtype,
        })
    }

    pub fn device(&self) -> &candle::Device {
        &self.device
    }

    /// Sample rate expected for the input pcm data, taken from the mimi
    /// configuration.
    pub fn sample_rate(&self) -> f64 {
        self.mimi.config().sample_rate
    }

    /// Encodes up to `n_speakers` reference pcms into one conditioning
    /// tensor; extra pcms are ignored, missing slots are filled with
    /// `learnt_padding`.
    ///
    /// Errors when `speakers` is empty.
    pub fn encode(&self, speakers: &[Tensor]) -> Result<Tensor> {
        if speakers.is_empty() {
            candle::bail!("empty speakers in encode")
        }
        let mut pcms = vec![];
        for pcm in speakers.iter().take(self.n_speakers) {
            // Normalize each pcm so its standard deviation becomes 0.08
            // (the mean is not subtracted from the output, only used for
            // the stdev estimate). The 0.08 constant presumably matches the
            // training-time loudness normalization — TODO confirm.
            let stdev = pcm.broadcast_sub(&pcm.mean_all()?)?.sqr()?.mean_all()?.sqrt()?;
            let pcm = (pcm * 0.08)?.broadcast_div(&stdev)?;
            pcms.push(pcm)
        }
        let n_speakers = pcms.len();
        let pcm = Tensor::cat(&pcms, 0)?;
        // Clone so that resetting the streaming state does not affect the
        // shared mimi instance.
        let mut mimi = self.mimi.clone();
        mimi.reset_state();
        let embeddings = mimi.encode_pre_quantize(&pcm)?.t()?.apply(&self.proj)?;
        // Fill missing speaker slots with the learnt padding embedding,
        // broadcast to the shape of one speaker's embeddings.
        let embeddings = if n_speakers < self.n_speakers {
            let lp =
                embeddings.narrow(0, 0, 1)?.zeros_like()?.broadcast_add(&self.learnt_padding)?;
            let mut embs = vec![embeddings];
            embs.resize(self.n_speakers - n_speakers + 1, lp);
            Tensor::cat(&embs, 0)?
        } else {
            embeddings
        };
        // Merge the speaker and time axes into a single sequence dimension
        // with a batch dimension of 1, then add positional embeddings.
        let embeddings = embeddings.flatten(0, 1)?.unsqueeze(0)?;
        let embeddings = crate::tts::add_sin_embeddings(&embeddings)?;
        embeddings.to_dtype(self.dtype)
    }

    /// Conditioning tensor for "no reference speakers": every slot filled
    /// with the learnt padding embedding.
    ///
    /// NOTE(review): the hard-coded 125 looks like the per-speaker frame
    /// count (10 s at 12.5 Hz mimi frames, cf.
    /// `Config::speaker_cond_duration_s`) — confirm against the model.
    pub fn empty(&self) -> Result<Tensor> {
        let embeddings =
            self.learnt_padding.broadcast_as((1, self.n_speakers * 125, self.cond_dim))?;
        let embeddings = crate::tts::add_sin_embeddings(&embeddings)?;
        embeddings.to_dtype(self.dtype)
    }
}