1use candle::{IndexOp, Result, Tensor};
15use candle_transformers::generation::LogitsProcessor;
16
/// State of a single slot in the text or audio token buffers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Token {
    /// The slot holds a concrete token id.
    Set(u32),
    /// The slot has not been filled yet; `State::step` samples a value for it.
    Ungenerated,
    /// The slot is deliberately left without a token id: `State::step` passes `None`
    /// to the model for it and never overwrites it with a sampled value.
    // NOTE(review): presumably the model treats the missing token as an all-zero
    // embedding (hence the name) - confirm against `LmModel`.
    LiteralZero,
}
23
24#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
25pub struct Config {
26 pub audio_delays: Vec<usize>,
27 pub audio_vocab_size: usize,
28 pub text_pad_token: u32,
29 pub text_eop_token: u32,
30 pub text_start_token: u32,
31}
32
33impl Config {
34 pub fn audio_pad_token(&self) -> u32 {
35 self.audio_vocab_size as u32 - 1
36 }
37
38 pub fn audio_codebooks(&self) -> usize {
39 self.audio_delays.len()
40 }
41
42 pub fn max_audio_delay(&self) -> usize {
43 self.audio_delays.iter().max().cloned().unwrap_or(0)
44 }
45}
46
47pub struct State {
48 model: crate::lm::LmModel,
49 audio_tokens: Vec<Vec<Token>>,
50 text_tokens: Vec<Token>,
51 audio_lp: LogitsProcessor,
52 text_lp: LogitsProcessor,
53 step_idx: usize,
54 config: Config,
55}
56
57impl State {
58 pub fn new(
59 model: crate::lm::LmModel,
60 max_step_idx: usize,
61 audio_lp: LogitsProcessor,
62 text_lp: LogitsProcessor,
63 config: Config,
64 ) -> Self {
65 let total_len = max_step_idx + config.max_audio_delay();
67 let audio_tokens = vec![vec![Token::Ungenerated; config.audio_codebooks()]; total_len];
68 let text_tokens = vec![Token::Ungenerated; total_len];
69 Self { model, audio_tokens, text_tokens, audio_lp, text_lp, step_idx: 0, config }
70 }
71
72 pub fn step_idx(&self) -> usize {
73 self.step_idx
74 }
75
76 pub fn audio_pad_token(&self) -> u32 {
77 self.config.audio_pad_token()
78 }
79
80 pub fn config(&self) -> &Config {
81 &self.config
82 }
83
84 pub fn set_audio_tokens(&mut self, audio_tokens: &[Option<Token>]) -> Result<()> {
85 for (s, at) in self.audio_tokens[self.step_idx].iter_mut().zip(audio_tokens.iter()) {
86 if let Some(at) = at {
87 *s = *at
88 }
89 }
90 Ok(())
91 }
92
93 pub fn step(&mut self, conditions: Option<&crate::conditioner::Condition>) -> Result<()> {
94 let dev = self.model.device();
95
96 let mut forced_audio_tokens = Vec::with_capacity(self.config.audio_codebooks());
97 for (codebook, &delay) in self.config.audio_delays.iter().enumerate() {
98 let forced_token = if self.step_idx < delay {
99 Some(self.audio_pad_token())
100 } else {
101 match self.audio_tokens[self.step_idx - delay][codebook] {
102 Token::Ungenerated | Token::LiteralZero => None,
103 Token::Set(v) => Some(v),
104 }
105 };
106 forced_audio_tokens.push(forced_token);
107 }
108
109 let mut codes = Vec::with_capacity(self.config.audio_codebooks());
110 for (codebook, &delay) in self.config.audio_delays.iter().enumerate() {
111 let t = if self.step_idx <= delay {
112 Some(self.audio_pad_token())
113 } else {
114 match self.audio_tokens[self.step_idx - delay - 1][codebook] {
115 Token::LiteralZero => None,
116 Token::Set(v) => Some(v),
117 Token::Ungenerated => {
118 candle::bail!("internal error, ungenerated {} {codebook}", self.step_idx)
119 }
120 }
121 };
122 let t = match t {
123 None => None,
124 Some(t) => Some(Tensor::from_vec(vec![t; 1], (1, 1), dev)?),
125 };
126 codes.push(t)
127 }
128 let text_token = if self.step_idx == 0 {
129 Some(self.config.text_start_token)
130 } else {
131 match self.text_tokens[self.step_idx - 1] {
132 Token::LiteralZero => None,
133 Token::Set(t) => Some(t),
134 Token::Ungenerated => {
135 candle::bail!("internal error, ungenerated {} text", self.step_idx)
136 }
137 }
138 };
139 let text_token = match text_token {
140 None => None,
141 Some(t) => Some(Tensor::from_vec(vec![t; 1], (1, 1), dev)?),
142 };
143 let (text_logits, ys) =
144 self.model.forward_cond(text_token, codes, conditions, &().into())?;
145 let text_token = match self.text_tokens[self.step_idx] {
146 Token::Ungenerated => {
147 let t = self.text_lp.sample(&text_logits.i((0, 0))?)?;
148 self.text_tokens[self.step_idx] = Token::Set(t);
149 Some(t)
150 }
151 Token::Set(t) => Some(t),
152 Token::LiteralZero => None,
153 };
154 let audio_tokens = self.model.depformer_sample(
155 &ys,
156 text_token,
157 &forced_audio_tokens,
158 &mut self.audio_lp,
159 )?;
160 if let Some(audio_tokens) = audio_tokens {
161 for (codebook, audio_token) in audio_tokens.into_iter().enumerate() {
162 let delay = self.config.audio_delays[codebook];
163 if self.step_idx < delay {
164 continue;
165 }
166 let pos = &mut self.audio_tokens[self.step_idx - delay][codebook];
167 if *pos == Token::Ungenerated {
168 *pos = Token::Set(audio_token)
169 }
170 }
171 }
172 self.step_idx += 1;
173 if self.step_idx >= self.audio_tokens.len() {
174 candle::bail!("max step-idx reached")
175 }
176 Ok(())
177 }
178
179 pub fn last_text_token(&self) -> Result<Option<u32>> {
180 if self.step_idx == 0 {
181 Ok(None)
182 } else {
183 match self.text_tokens[self.step_idx - 1] {
184 Token::Set(t) => Ok(Some(t)),
185 Token::LiteralZero => Ok(None),
186 Token::Ungenerated => {
187 candle::bail!("internal error, ungenerated step {}, text", self.step_idx)
188 }
189 }
190 }
191 }
192
193 pub fn last_audio_tokens(&self) -> Result<Option<Vec<u32>>> {
194 let max_audio_delay = self.config.max_audio_delay();
195 if self.step_idx <= max_audio_delay {
196 Ok(None)
197 } else {
198 let mut audio_tokens = vec![];
199 for (cb, audio_token) in
200 self.audio_tokens[self.step_idx - max_audio_delay - 1].iter().enumerate()
201 {
202 match audio_token {
203 Token::LiteralZero => return Ok(None),
204 Token::Set(s) => audio_tokens.push(*s),
205 Token::Ungenerated => {
206 candle::bail!("internal error, ungenerated step {}, cb {cb}", self.step_idx)
207 }
208 }
209 }
210 Ok(Some(audio_tokens))
211 }
212 }
213}