rlx_qwen3_tts/code_predictor/
engine.rs1use crate::code_predictor::compiled::CpCompiledEngine;
19use crate::code_predictor::eager::CpEagerModel;
20use crate::config::CodePredictorConfig;
21use crate::load::Qwen3TtsWeightStore;
22use anyhow::{Context, Result, ensure};
23use ndarray::{Array2, ArrayView1};
24use rlx_runtime::Device;
25
/// Reports whether the `RLX_QWEN3_TTS_CP_EAGER` environment variable is set to exactly `"1"`.
///
/// An unset variable, a value that is not valid Unicode, or any other value yields `false`.
fn cp_force_eager() -> bool {
    matches!(std::env::var("RLX_QWEN3_TTS_CP_EAGER"), Ok(value) if value == "1")
}
29
30pub fn cp_use_compiled_for_device(talker_device: Device) -> bool {
32 if cp_force_eager() {
33 return false;
34 }
35 if std::env::var("RLX_QWEN3_TTS_CP_COMPILED").ok().as_deref() == Some("1") {
36 return true;
37 }
38 if crate::gpu_pipeline::gpu_session_enabled(talker_device) {
39 return crate::gpu_pipeline::cp_use_gpu_on_device(talker_device);
40 }
41 talker_device != Device::Cpu && talker_device != Device::Metal
42}
43
44fn cp_execution_device(talker_device: Device) -> Device {
45 if !cp_use_compiled_for_device(talker_device) {
46 Device::Cpu
47 } else {
48 crate::compile_opts::cp_compile_device(talker_device)
49 }
50}
51
/// The two code-predictor implementations a `CodePredictorEngine` can wrap.
enum CpBackend {
    /// Eager model; chosen when `cp_use_compiled_for_device` returns `false`.
    Eager(CpEagerModel),
    /// Compiled engine; chosen when `cp_use_compiled_for_device` returns `true`.
    Compiled(CpCompiledEngine),
}
56
/// Code predictor together with the embedding tables and LM heads it needs.
///
/// Each weight is kept twice: as an `Array2<f32>` and as the flat `Vec<f32>` it was built from.
pub struct CodePredictorEngine {
    /// Device passed to `open`.
    talker_device: Device,
    /// Device returned by `cp_execution_device` for `talker_device`; reported by `device()`.
    cp_device: Device,
    /// Eager or compiled implementation selected in `open`.
    backend: CpBackend,
    /// `talker.model.codec_embedding.weight` as a 2-D array; rows are looked up by group-0 token.
    talker_codec: Array2<f32>,
    /// Flat data of the same tensor as `talker_codec`.
    talker_codec_flat: Vec<f32>,
    /// `talker.code_predictor.model.codec_embedding.{i}.weight` tables; entry `g - 1` serves
    /// group `g`.
    group_embeds: Vec<Array2<f32>>,
    /// Flat data of each table in `group_embeds`, in the same order.
    group_embed_flat: Vec<Vec<f32>>,
    /// `talker.code_predictor.lm_head.{i}.weight` matrices.
    lm_heads: Vec<Array2<f32>>,
    /// Flat data of each matrix in `lm_heads`, in the same order.
    lm_head_flat: Vec<Vec<f32>>,
    /// First dimension of each LM head tensor.
    lm_head_vocab: Vec<usize>,
    /// `CodePredictorConfig::hidden_size`; the length required of hidden-state inputs and
    /// embedding output buffers.
    hidden: usize,
}
70
71impl CodePredictorEngine {
72 pub fn open(
73 store: &Qwen3TtsWeightStore,
74 cp: &CodePredictorConfig,
75 device: Device,
76 ) -> Result<Self> {
77 let talker_snap = store.tensor_snapshot(&["talker.model.codec_embedding.weight"])?;
78 let (tc_data, tc_shape) = talker_snap
79 .get("talker.model.codec_embedding.weight")
80 .context("talker codec_embedding")?;
81 let talker_codec_flat = tc_data.clone();
82 let talker_codec =
83 Array2::from_shape_vec((tc_shape[0], tc_shape[1]), talker_codec_flat.clone())?;
84
85 let mut group_embeds = Vec::with_capacity(cp.num_code_groups - 1);
86 let mut group_embed_flat = Vec::with_capacity(cp.num_code_groups - 1);
87 for i in 0..cp.num_code_groups - 1 {
88 let key = format!("talker.code_predictor.model.codec_embedding.{i}.weight");
89 let (data, shape) = store.tensor_snapshot(&[&key])?[&key].clone();
90 group_embeds.push(Array2::from_shape_vec((shape[0], shape[1]), data.clone())?);
91 group_embed_flat.push(data);
92 }
93 let mut lm_heads = Vec::with_capacity(cp.num_code_groups - 1);
94 let mut lm_head_flat = Vec::with_capacity(cp.num_code_groups - 1);
95 let mut lm_head_vocab = Vec::with_capacity(cp.num_code_groups - 1);
96 for i in 0..cp.num_code_groups - 1 {
97 let key = format!("talker.code_predictor.lm_head.{i}.weight");
98 let (data, shape) = store.tensor_snapshot(&[&key])?[&key].clone();
99 lm_head_vocab.push(shape[0]);
100 lm_head_flat.push(data.clone());
101 lm_heads.push(Array2::from_shape_vec((shape[0], shape[1]), data)?);
102 }
103
104 let cp_device = cp_execution_device(device);
105 let backend = if cp_use_compiled_for_device(device) {
106 CpBackend::Compiled(CpCompiledEngine::open(
107 store.model_dir(),
108 store,
109 cp,
110 cp_device,
111 )?)
112 } else {
113 CpBackend::Eager(CpEagerModel::open(store, cp)?)
114 };
115
116 Ok(Self {
117 talker_device: device,
118 cp_device,
119 backend,
120 talker_codec,
121 talker_codec_flat,
122 group_embeds,
123 group_embed_flat,
124 lm_heads,
125 lm_head_flat,
126 lm_head_vocab,
127 hidden: cp.hidden_size,
128 })
129 }
130
131 pub fn is_eager(&self) -> bool {
132 matches!(self.backend, CpBackend::Eager(_))
133 }
134
135 pub fn talker_codec_flat(&self) -> (&[f32], usize) {
139 (&self.talker_codec_flat, self.hidden)
140 }
141
/// Returns the device the code predictor executes on (the value chosen by
/// `cp_execution_device` in `open`), which is not necessarily the talker device.
pub fn device(&self) -> Device {
    self.cp_device
}
145
146 pub fn cp_backend_label(&self) -> String {
147 match &self.backend {
148 CpBackend::Eager(_) => "CPU eager".into(),
149 CpBackend::Compiled(_) if self.cp_device != self.talker_device => {
150 format!("compiled (CPU, talker {:?})", self.talker_device)
151 }
152 CpBackend::Compiled(_) => format!("compiled ({:?})", self.cp_device),
153 }
154 }
155
156 pub fn warmup(&mut self, max_frames: usize) -> Result<()> {
157 match &mut self.backend {
158 CpBackend::Eager(e) => {
159 let mut hidden = vec![0f32; self.hidden];
160 for (i, v) in hidden.iter_mut().enumerate() {
161 *v = ((i % 17) as f32) * 1e-5;
162 }
163 let _ = e.predict_groups(
164 &self.talker_codec,
165 &self.group_embeds,
166 &self.lm_heads,
167 ArrayView1::from(&hidden),
168 1995,
169 )?;
170 Ok(())
171 }
172 CpBackend::Compiled(c) => c.warmup(max_frames),
173 }
174 }
175
176 pub fn predict_groups_slice(&mut self, talker_hidden: &[f32], group0: u32) -> Result<Vec<u32>> {
177 self.predict_groups(ArrayView1::from(talker_hidden), group0)
178 }
179
180 pub fn predict_groups_fill_emb(
182 &mut self,
183 talker_hidden: &[f32],
184 group0: u32,
185 pad: &[f32],
186 codec_emb: &mut [f32],
187 ) -> Result<Vec<u32>> {
188 ensure!(codec_emb.len() == self.hidden);
189 match &mut self.backend {
190 CpBackend::Eager(e) => e.predict_groups_fill_emb_flat(
191 &self.talker_codec_flat,
192 &self.group_embed_flat,
193 &self.lm_head_flat,
194 &self.lm_head_vocab,
195 ArrayView1::from(talker_hidden),
196 group0,
197 pad,
198 codec_emb,
199 self.hidden,
200 ),
201 CpBackend::Compiled(c) => {
202 let groups = c.predict_groups(
203 &self.talker_codec,
204 &self.group_embeds,
205 &self.lm_heads,
206 ArrayView1::from(talker_hidden),
207 group0,
208 )?;
209 codec_emb.fill(0.0);
210 self.sum_codec_groups_into(&groups, codec_emb)?;
211 for (j, v) in pad.iter().enumerate() {
212 codec_emb[j] += *v;
213 }
214 Ok(groups)
215 }
216 }
217 }
218
219 pub fn predict_groups(
220 &mut self,
221 talker_hidden: ArrayView1<f32>,
222 group0: u32,
223 ) -> Result<Vec<u32>> {
224 ensure!(talker_hidden.len() == self.hidden);
225 match &mut self.backend {
226 CpBackend::Eager(e) => e.predict_groups(
227 &self.talker_codec,
228 &self.group_embeds,
229 &self.lm_heads,
230 talker_hidden,
231 group0,
232 ),
233 CpBackend::Compiled(c) => c.predict_groups(
234 &self.talker_codec,
235 &self.group_embeds,
236 &self.lm_heads,
237 talker_hidden,
238 group0,
239 ),
240 }
241 }
242
243 pub fn sum_codec_groups_into(&self, groups: &[u32], out: &mut [f32]) -> Result<()> {
245 ensure!(out.len() == self.hidden, "codec emb buffer len mismatch");
246 out.fill(0.0);
247 for (gi, &tok) in groups.iter().enumerate() {
248 if gi == 0 {
249 ensure!(
250 (tok as usize) < self.talker_codec.nrows(),
251 "group0 token {tok} oob"
252 );
253 for (j, v) in self.talker_codec.row(tok as usize).iter().enumerate() {
254 out[j] += *v;
255 }
256 } else {
257 let table = &self.group_embeds[gi - 1];
258 ensure!(
259 (tok as usize) < table.nrows(),
260 "token {tok} oob for group {gi}"
261 );
262 for (j, v) in table.row(tok as usize).iter().enumerate() {
263 out[j] += *v;
264 }
265 }
266 }
267 Ok(())
268 }
269
270 pub fn sum_codec_groups(&self, groups: &[u32]) -> Result<Vec<f32>> {
271 let mut emb = vec![0f32; self.hidden];
272 self.sum_codec_groups_into(groups, &mut emb)?;
273 Ok(emb)
274 }
275
276 pub fn cp_step_embeds_from_groups(&self, groups: &[u32]) -> Result<Vec<Vec<f32>>> {
278 use crate::cp_frame::CP_DECODE_BACKBONE_STEPS;
279 ensure!(
280 groups.len() > CP_DECODE_BACKBONE_STEPS,
281 "groups len {} < {}",
282 groups.len(),
283 1 + CP_DECODE_BACKBONE_STEPS
284 );
285 let mut out = Vec::with_capacity(CP_DECODE_BACKBONE_STEPS);
286 for step in 0..CP_DECODE_BACKBONE_STEPS {
287 out.push(self.codec_embed_row(step + 1, groups[step + 1])?);
288 }
289 Ok(out)
290 }
291
292 pub fn codec_embed_row(&self, group_idx: usize, token: u32) -> Result<Vec<f32>> {
293 if group_idx == 0 {
294 ensure!(
295 (token as usize) < self.talker_codec.nrows(),
296 "group0 token {token} oob"
297 );
298 return Ok(self.talker_codec.row(token as usize).to_vec());
299 }
300 let gi = group_idx - 1;
301 ensure!(gi < self.group_embeds.len(), "group_idx {group_idx} oob");
302 let table = &self.group_embeds[gi];
303 ensure!(
304 (token as usize) < table.nrows(),
305 "token {token} oob for group {group_idx}"
306 );
307 Ok(table.row(token as usize).to_vec())
308 }
309}