use candle::{IndexOp, Tensor};
use candle_transformers::generation::LogitsProcessor;

use crate::transformer::CaSrc;

/// Sentinel value for token slots that have not been generated yet.
pub const UNGENERATED: u32 = u32::MAX;

/// Stream layout and special-token configuration: codebook counts, audio
/// vocabulary size, acoustic delay, and the special text tokens.
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
pub struct Config {
    pub generated_audio_codebooks: usize,
    pub input_audio_codebooks: usize,
    pub audio_vocab_size: usize,
    pub acoustic_delay: usize,
    pub text_pad_token: u32,
    pub text_eop_token: u32,
    pub text_start_token: u32,
}

impl Config {
    pub fn v0_1() -> Self {
        Self {
            generated_audio_codebooks: 8,
            input_audio_codebooks: 8,
            audio_vocab_size: 2049,
            acoustic_delay: 2,
            text_eop_token: 0,
            text_pad_token: 3,
            text_start_token: 32000,
        }
    }

    pub fn v0_1_two_ways() -> Self {
        Self {
            generated_audio_codebooks: 16,
            input_audio_codebooks: 0,
            audio_vocab_size: 2049,
            acoustic_delay: 2,
            text_eop_token: 0,
            text_pad_token: 3,
            text_start_token: 32000,
        }
    }

    pub fn v0_1_one_way() -> Self {
        Self {
            generated_audio_codebooks: 8,
            input_audio_codebooks: 0,
            audio_vocab_size: 2049,
            acoustic_delay: 2,
            text_eop_token: 0,
            text_pad_token: 3,
            text_start_token: 32000,
        }
    }

    /// The audio pad token is the last entry of the audio vocabulary.
    pub fn audio_pad_token(&self) -> u32 {
        self.audio_vocab_size as u32 - 1
    }

    /// Total number of audio codebooks handled per step (generated + input).
    pub fn total_audio_codebooks(&self) -> usize {
        self.generated_audio_codebooks + self.input_audio_codebooks
    }
}

pub struct State {
    model: crate::lm::LmModel,
    // Indexed as audio_tokens[step][codebook]; slots start out as UNGENERATED.
    audio_tokens: Vec<Vec<u32>>,
    text_tokens: Vec<u32>,
    audio_lp: LogitsProcessor,
    text_lp: LogitsProcessor,
    step_idx: usize,
    pad_mult: Option<f32>,
    repetition_penalty: Option<(usize, f32)>,
    forced_audio_tokens: crate::lm::ForcedAudioTokens,
    user_rating: u32,
    cfg_alpha: Option<f64>,
    config: Config,
}

impl State {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        model: crate::lm::LmModel,
        max_step_idx: usize,
        audio_lp: LogitsProcessor,
        text_lp: LogitsProcessor,
        pad_mult: Option<f32>,
        repetition_penalty: Option<(usize, f32)>,
        cfg_alpha: Option<f64>,
        config: Config,
    ) -> Self {
        let audio_tokens: Vec<Vec<u32>> = vec![
            vec![UNGENERATED; config.total_audio_codebooks()];
            max_step_idx + config.acoustic_delay
        ];
        let text_tokens = vec![UNGENERATED; max_step_idx + config.acoustic_delay];
        let forced_audio_tokens = crate::lm::ForcedAudioTokens::new(
            config.acoustic_delay,
            config.audio_pad_token(),
            &[8, 8],
        );
        Self {
            model,
            audio_tokens,
            text_tokens,
            audio_lp,
            text_lp,
            step_idx: 0,
            pad_mult,
            repetition_penalty,
            forced_audio_tokens,
            user_rating: 0,
            cfg_alpha,
            config,
        }
    }
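
    // Rough construction sketch (illustrative only; how the `LmModel` is loaded
    // is up to the caller and elided here, and the sampling settings below are
    // arbitrary):
    //
    //     let audio_lp = LogitsProcessor::new(42, Some(0.8), None);
    //     let text_lp = LogitsProcessor::new(42, Some(0.7), None);
    //     let state = State::new(
    //         model, 1000, audio_lp, text_lp,
    //         /* pad_mult */ None, /* repetition_penalty */ Some((32, 1.1)),
    //         /* cfg_alpha */ None, Config::v0_1(),
    //     );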

    pub fn step_idx(&self) -> usize {
        self.step_idx
    }

    fn audio_pad_token(&self) -> u32 {
        self.config.audio_pad_token()
    }

    pub fn config(&self) -> &Config {
        &self.config
    }

    pub fn user_rating(&self) -> u32 {
        self.user_rating
    }

    pub fn set_user_rating(&mut self, grade: u32) {
        self.user_rating = grade
    }

    /// Penalize text tokens that already appeared in the recent non-pad context:
    /// positive logits are divided by `penalty`, negative ones multiplied by it.
    fn apply_repetition_penalty(&self, logits: Tensor) -> candle::Result<Tensor> {
        let logits = match self.repetition_penalty {
            None => logits,
            Some((_, 1.)) => logits,
            Some((context_size, penalty)) => {
                let device = logits.device();
                let mut logits = logits.to_dtype(candle::DType::F32)?.to_vec1::<f32>()?;
                let mut already_seen = std::collections::HashSet::new();
                let mut non_pad_tokens = 0;
                // Walk backwards through the sampled text tokens, skipping the
                // pad/eop/start markers, until `context_size` tokens have been seen.
                for &token_id in self.text_tokens(false).iter().rev() {
                    if token_id == self.config.text_pad_token
                        || token_id == self.config.text_eop_token
                        || token_id == self.config.text_start_token
                    {
                        continue;
                    }
                    if non_pad_tokens >= context_size {
                        break;
                    }
                    non_pad_tokens += 1;

                    if already_seen.contains(&token_id) {
                        continue;
                    }

                    already_seen.insert(token_id);
                    if let Some(logit) = logits.get_mut(token_id as usize) {
                        if *logit >= 0. {
                            *logit /= penalty
                        } else {
                            *logit *= penalty
                        }
                    }
                }
                let logits_len = logits.len();
                Tensor::from_vec(logits, logits_len, device)?
            }
        };
        Ok(logits)
    }
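
    // Worked example for the rule above (illustrative numbers): with
    // `penalty = 1.5`, a previously seen token with logit 2.0 becomes
    // 2.0 / 1.5 ≈ 1.33, while one with logit -2.0 becomes -2.0 * 1.5 = -3.0,
    // so repeated tokens end up less likely regardless of the logit's sign.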

    /// Run one generation step: build the (possibly delayed) codebook inputs,
    /// run the model, sample the next text token, then sample the audio codebooks.
    pub fn step_(
        &mut self,
        text_token: Option<u32>,
        input_audio_tokens: &[u32],
        force_text_token: Option<u32>,
        ca_src: Option<&CaSrc>,
        conditions: Option<&crate::conditioner::Condition>,
    ) -> candle::Result<u32> {
        let mut codes = Vec::with_capacity(self.config.total_audio_codebooks());
        let dev = self.model.device();
        // Store this step's input-stream audio tokens after the generated codebooks.
        for (c_idx, &t) in input_audio_tokens.iter().enumerate() {
            self.audio_tokens[self.step_idx][c_idx + self.config.generated_audio_codebooks] = t
        }
        // With classifier-free guidance enabled, the model is run on a batch of size 2.
        let batch_size = if self.cfg_alpha.is_some() { 2 } else { 1 };
        for codebook in 0..self.config.total_audio_codebooks() {
            // Semantic codebooks are fed without delay, acoustic codebooks with a
            // delay of `acoustic_delay` steps; pad tokens are used until then.
            let t = if codebook == 0 || codebook == self.config.generated_audio_codebooks {
                if self.step_idx == 0 {
                    self.audio_pad_token()
                } else {
                    self.audio_tokens[self.step_idx - 1][codebook]
                }
            } else if self.step_idx <= self.config.acoustic_delay {
                self.audio_pad_token()
            } else {
                self.audio_tokens[self.step_idx - self.config.acoustic_delay - 1][codebook]
            };
            if t == UNGENERATED {
                candle::bail!("internal error, ungenerated {} {codebook}", self.step_idx)
            }
            let t = Tensor::from_vec(vec![t; batch_size], (batch_size, 1), dev)?;
            codes.push(Some(t))
        }
        let text_token = match text_token {
            Some(text_token) => {
                Some(Tensor::from_vec(vec![text_token; batch_size], (batch_size, 1), dev)?)
            }
            None => None,
        };
        let (text_logits, ys) = match ca_src.as_ref() {
            None => {
                let (logits, ys) =
                    self.model.forward_cond(text_token, codes, conditions, &().into())?;
                let logits = match self.cfg_alpha {
                    None => logits.i((0, 0))?,
                    // Classifier-free guidance blend: a * row0 - (a - 1) * row1.
                    Some(a) => match logits.dim(0)? {
                        2 => ((logits.i((0, 0))? * a)? - (logits.i((1, 0))? * (a - 1.))?)?,
                        b_size => candle::bail!("unexpected batch size {b_size}"),
                    },
                };
                (logits, ys)
            }
            Some(ca_src) => {
                if self.cfg_alpha.is_some() {
                    candle::bail!("cfg is not supported with cross attention")
                }
                let (logits, ys) =
                    self.model.forward_ca(text_token, codes, ca_src, None, &().into())?;
                (logits.i((0, 0))?, ys)
            }
        };
        let text_logits = self.apply_repetition_penalty(text_logits)?;
        let text_token = match force_text_token {
            Some(tt) => tt,
            None => self.text_lp.sample_f(&text_logits, |prs| {
                // Optionally rescale the probability of the text pad token.
                if let Some(pad_mult) = self.pad_mult.as_ref() {
                    prs[self.config.text_pad_token as usize] *= f32::exp(*pad_mult);
                }
            })?,
        };
        self.text_tokens[self.step_idx] = text_token;
        let last_audio_tokens = match self.cfg_alpha {
            None => self.model.depformer_sample(
                &ys,
                Some(text_token),
                self.forced_audio_tokens.forced_tokens(self.step_idx),
                &mut self.audio_lp,
            )?,
            Some(cfg_alpha) => self.model.depformer_sample_cfg(
                &ys,
                cfg_alpha,
                Some(text_token),
                self.forced_audio_tokens.forced_tokens(self.step_idx),
                &mut self.audio_lp,
            )?,
        };
        let audio_pad_token = self.audio_pad_token();
        // Write the sampled audio tokens back, delaying the acoustic codebooks.
        for c_idx in 0..self.config.generated_audio_codebooks {
            let delay = if c_idx == 0 || c_idx == self.config.generated_audio_codebooks {
                0
            } else {
                self.config.acoustic_delay
            };
            let pos = &mut self.audio_tokens[self.step_idx.saturating_sub(delay)][c_idx];
            *pos = last_audio_tokens.as_ref().map_or(audio_pad_token, |l| l[c_idx]);
        }
        self.step_idx += 1;
        if self.step_idx >= self.audio_tokens.len() {
            candle::bail!("max step-idx reached")
        }
        Ok(text_token)
    }
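
    // Rough sketch of a loop driving `step` (illustrative only; the input audio
    // source, the cross-attention source, and the stopping condition are
    // placeholders, not part of this file):
    //
    //     let mut text_token = state.config().text_start_token;
    //     for _ in 0..max_steps {
    //         let input_audio = next_input_frame(); // one token per input codebook
    //         text_token = state.step(text_token, &input_audio, None, None)?;
    //         if let Some(audio) = state.last_audio_tokens() {
    //             // feed `audio` to the audio decoder / vocoder
    //         }
    //     }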

    pub fn step_without_ca_src(
        &mut self,
        text_token: u32,
        input_audio_tokens: &[u32],
        force_text_token: Option<u32>,
    ) -> candle::Result<u32> {
        self.step_(Some(text_token), input_audio_tokens, force_text_token, None, None)
    }

    pub fn step(
        &mut self,
        text_token: u32,
        input_audio_tokens: &[u32],
        force_text_token: Option<u32>,
        ca_src: Option<&CaSrc>,
    ) -> candle::Result<u32> {
        self.step_(Some(text_token), input_audio_tokens, force_text_token, ca_src, None)
    }

    /// The audio token buffer, either in full (`include_all`) or truncated to the
    /// steps taken so far.
    pub fn audio_tokens(&self, include_all: bool) -> &[Vec<u32>] {
        if include_all {
            &self.audio_tokens
        } else {
            let max_idx = usize::min(self.step_idx, self.audio_tokens.len());
            &self.audio_tokens[..max_idx]
        }
    }
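
    // Sketch of how a caller might split one returned row into its generated and
    // input halves (illustrative only; relies on the layout used above, where the
    // generated codebooks are stored first):
    //
    //     let row = &state.audio_tokens(false)[0];
    //     let (generated, input) = row.split_at(state.config().generated_audio_codebooks);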

    pub fn text_tokens(&self, include_all: bool) -> &[u32] {
        if include_all {
            &self.text_tokens
        } else {
            let max_idx = usize::min(self.step_idx, self.text_tokens.len());
            &self.text_tokens[..max_idx]
        }
    }

    /// The most recent fully populated row of audio tokens, i.e. the one at
    /// `step_idx - acoustic_delay - 1`. Returns `None` while still within the
    /// acoustic-delay window, or if any codebook holds the pad (or ungenerated) value.
    pub fn last_audio_tokens(&self) -> Option<Vec<u32>> {
        if self.step_idx <= self.config.acoustic_delay {
            None
        } else {
            let audio_tokens = &self.audio_tokens[self.step_idx - self.config.acoustic_delay - 1];
            if audio_tokens.iter().any(|v| *v as usize >= self.config.audio_vocab_size - 1) {
                None
            } else {
                Some(audio_tokens.clone())
            }
        }
    }
}
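
#[cfg(test)]
mod tests {
    use super::*;

    // Model-free sanity checks for the `Config` presets; a minimal sketch that
    // only exercises values visible in this file.
    #[test]
    fn config_presets() {
        let cfg = Config::v0_1();
        assert_eq!(cfg.total_audio_codebooks(), 16);
        assert_eq!(cfg.audio_pad_token(), 2048); // audio_vocab_size - 1

        let one_way = Config::v0_1_one_way();
        assert_eq!(one_way.total_audio_codebooks(), 8);

        let two_ways = Config::v0_1_two_ways();
        assert_eq!(two_ways.generated_audio_codebooks, 16);
        assert_eq!(two_ways.input_audio_codebooks, 0);
    }
}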