1use crate::codec_frame::{Qwen3TtsGraphProfiles, Qwen3TtsGraphRole, cp_decode_graph_parts};
19use crate::compile_opts::{cp_compile_device, metal_compile_guard, talker_compile_options};
20use crate::config::CodePredictorConfig;
21use crate::cp_frame::build_qwen3_tts_cp_prefill_two_built;
22use crate::kv_util::commit_kv_layers;
23use crate::load::{Qwen3TtsWeightStore, remap_code_predictor_weights};
24use crate::talker::math::{
25 bucket_decode_hidden_into, last_decode_hidden_into, linear_logits_into, sample_greedy,
26};
27use crate::talker::rope::{rope_prefill_feeds, rope_slice, rope_tables_full};
28use crate::weights::weight_map_from_cache;
29use anyhow::{Result, ensure};
30use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
31use rlx_core::autoregressive::{KvCacheState, kv_from_prefill_outputs, run_bucketed_kv_decode};
32use rlx_core::flow_util::compile_cache_ensure_built_with_options;
33use rlx_flow::CompileProfile;
34use rlx_runtime::Device;
35use rlx_runtime::compile_cache::{BucketedCompileCache, CacheRunInput, CompileCache};
36use std::path::Path;
37use std::sync::Arc;
38
/// Largest sequence length the prefill paths in this file accept; also sizes the
/// prefill staging buffer (`hidden * CP_PREFILL_SEQ` values).
const CP_PREFILL_SEQ: usize = 2;
/// Upper bound passed to the decode cache's power-of-two bucket ladder and used to
/// filter bucket keys when decode graphs are precompiled.
const CP_DECODE_BUCKET_MAX: u64 = 32;
/// Cap applied to `max_position_embeddings` when the engine is opened; the capped
/// value is what the prefill paths pass as the RoPE table length.
const CP_ROPE_TABLE_LEN: usize = 4096;
43
/// Code-predictor engine holding compiled prefill/decode graph caches, the KV cache
/// and the scratch buffers reused between steps.
pub struct CpCompiledEngine {
    /// Model configuration produced by `CodePredictorConfig::to_qwen3_config`, with
    /// `max_position_embeddings` capped at `CP_ROPE_TABLE_LEN`.
    qwen3: rlx_qwen3::Qwen3Config,
    /// Device the caller passed to `open`.
    session_device: Device,
    /// Device returned by `cp_compile_device` for `session_device`; both compile
    /// caches are created on it.
    compile_device: Device,
    /// Width of one embedding / hidden row (`cp.hidden_size`).
    hidden: usize,
    /// Value of `qwen3.kv_proj_dim()`, forwarded to the KV helpers.
    kv_dim: usize,
    /// Number of layers (`cp.num_hidden_layers`); one K and one V buffer per layer.
    n_layers: usize,
    /// `cp.head_dim / 2`; used as the row stride when patching the prefill RoPE tables.
    head_half: usize,
    /// Output of `build_inv_freq(cp.head_dim, cp.rope_theta)`.
    inv_freq: Vec<f64>,
    /// Remapped code-predictor weights, shared with the graph-builder closures.
    weights: Arc<crate::load::TensorSnapshot>,
    /// Compile profile used for prefill graphs.
    prefill_profile: CompileProfile,
    /// Compile profile used for decode graphs.
    decode_profile: CompileProfile,
    /// Positions currently held in the KV cache: set to `seq` by a prefill and
    /// incremented by one per decode step.
    past_len: usize,
    /// Per-layer key/value cache.
    kv: KvCacheState,
    /// Cache of compiled prefill graphs, keyed by a value derived from the sequence length.
    prefill_cache: CompileCache,
    /// Bucketed cache of compiled decode graphs.
    decode_cache: BucketedCompileCache,
    /// Staging buffer for stacked prefill inputs (`hidden * CP_PREFILL_SEQ` values).
    prefill_scratch: Vec<f32>,
    /// Input embedding for the next decode step (`hidden` values).
    decode_embed: Vec<f32>,
    /// Hidden row written by the most recent prefill or decode step.
    hidden_row: Vec<f32>,
    /// Unprocessed hidden output of the most recent decode step; empty until one runs.
    last_raw_hidden: Vec<f32>,
    /// Logits scratch buffer (`cp.vocab_size` values).
    logits: Vec<f32>,
    /// Reusable buffer for the decode mask.
    mask_buf: Vec<f32>,
}
68
69fn cp_prefill_rope_feeds(
71 inv_freq: &[f64],
72 positions: &[usize],
73 head_dim: usize,
74 rope_table_len: usize,
75 head_half: usize,
76) -> (Vec<f32>, Vec<f32>) {
77 let (mut cos, mut sin) = rope_tables_full(inv_freq, rope_table_len, head_dim);
78 let (seq_cos, seq_sin) = rope_prefill_feeds(inv_freq, positions, head_dim);
79 for t in 0..positions.len() {
80 let off = t * head_half;
81 cos[off..off + head_half].copy_from_slice(&seq_cos[off..off + head_half]);
82 sin[off..off + head_half].copy_from_slice(&seq_sin[off..off + head_half]);
83 }
84 (cos, sin)
85}
86
87fn cp_compile_guard<R, F>(session_device: Device, compile_device: Device, f: F) -> R
88where
89 F: FnOnce() -> R,
90{
91 if compile_device == Device::Cpu {
92 f()
93 } else {
94 metal_compile_guard(session_device, f)
95 }
96}
97
/// Rewrites `out` as a mask of `upper + 1` entries.
///
/// Entries with index below `past_seq`, plus the final entry at index `upper`,
/// are `1.0`; every other entry is `0.0`. Any previous contents of `out` are
/// discarded.
fn bucket_decode_mask_into(past_seq: usize, upper: usize, out: &mut Vec<f32>) {
    let len = upper + 1;
    out.clear();
    out.extend((0..len).map(|slot| {
        let attended = slot < past_seq || slot == upper;
        if attended { 1.0 } else { 0.0 }
    }));
}
104
105impl CpCompiledEngine {
106 pub fn open(
107 model_dir: &Path,
108 store: &Qwen3TtsWeightStore,
109 cp: &CodePredictorConfig,
110 device: Device,
111 ) -> Result<Self> {
112 let mut wm = store.load_code_predictor_backbone()?;
113 let weights = remap_code_predictor_weights(&mut wm)?;
114 let compile_device = cp_compile_device(device);
115 let profiles = Qwen3TtsGraphProfiles::for_role(
116 model_dir,
117 Qwen3TtsGraphRole::CodePredictor,
118 compile_device,
119 );
120 let prefill = profiles.prefill;
121 let decode = profiles.decode;
122 let mut qwen3 = cp.to_qwen3_config();
123 qwen3.max_position_embeddings = qwen3.max_position_embeddings.min(CP_ROPE_TABLE_LEN);
124 let hidden = cp.hidden_size;
125 let kv_dim = qwen3.kv_proj_dim();
126 let n_layers = cp.num_hidden_layers;
127 let head_half = cp.head_dim / 2;
128 let inv_freq = crate::talker::rope::build_inv_freq(cp.head_dim, cp.rope_theta);
129 Ok(Self {
130 qwen3,
131 session_device: device,
132 compile_device,
133 hidden,
134 kv_dim,
135 n_layers,
136 head_half,
137 inv_freq,
138 weights: Arc::new(weights),
139 prefill_profile: prefill,
140 decode_profile: decode,
141 past_len: 0,
142 kv: KvCacheState {
143 past_len: 0,
144 layers_k: vec![Vec::new(); n_layers],
145 layers_v: vec![Vec::new(); n_layers],
146 },
147 prefill_cache: CompileCache::new(compile_device, 4),
148 decode_cache: BucketedCompileCache::power_of_two_ladder(
149 compile_device,
150 1,
151 CP_DECODE_BUCKET_MAX,
152 ),
153 prefill_scratch: vec![0f32; hidden * CP_PREFILL_SEQ],
154 decode_embed: vec![0f32; hidden],
155 hidden_row: vec![0f32; hidden],
156 last_raw_hidden: Vec::new(),
157 logits: vec![0f32; cp.vocab_size],
158 mask_buf: Vec::new(),
159 })
160 }
161
162 #[doc(hidden)]
163 pub fn last_raw_hidden(&self) -> &[f32] {
164 &self.last_raw_hidden
165 }
166
167 #[doc(hidden)]
168 pub fn export_kv_state(&self) -> (KvCacheState, usize) {
169 (self.kv.clone(), self.past_len)
170 }
171
172 #[doc(hidden)]
173 pub fn import_kv_state(&mut self, kv: KvCacheState, past_len: usize) {
174 self.kv = kv;
175 self.past_len = past_len;
176 }
177
178 pub fn warmup(&mut self, max_frames: usize) -> Result<()> {
179 let mut embeds = Array2::<f32>::zeros((CP_PREFILL_SEQ, self.hidden));
180 embeds[[0, 0]] = 1e-4;
181 self.reset_kv();
182 self.prefill(embeds.view())?;
183 if crate::synth_opts::lazy_talk_buckets()
184 && !crate::synth_opts::auto_precompile_horizon(max_frames)
185 {
186 let emb = vec![0f32; self.hidden];
187 let _ = self.decode_step(ArrayView1::from(&emb))?;
188 } else {
189 self.precompile_decode_buckets()?;
190 }
191 Ok(())
192 }
193
194 fn precompile_decode_buckets(&mut self) -> Result<()> {
196 let keys: Vec<u64> = self
197 .decode_cache
198 .buckets()
199 .map(|r| r.end.saturating_sub(1))
200 .filter(|&k| k <= CP_DECODE_BUCKET_MAX)
201 .collect();
202 let opts = talker_compile_options(&self.decode_profile, self.compile_device);
203 for &key in &keys {
204 let weights = Arc::clone(&self.weights);
205 let qwen3 = self.qwen3.clone();
206 let decode_profile = self.decode_profile.clone();
207 cp_compile_guard(self.session_device, self.compile_device, || {
208 let _ = self.decode_cache.ensure_graph_with_params(
209 key,
210 move |upper| {
211 cp_decode_graph_parts(&qwen3, weights.as_ref(), &decode_profile, upper)
212 .expect("cp decode graph")
213 },
214 &opts,
215 );
216 });
217 }
218 Ok(())
219 }
220
221 pub fn reset_kv(&mut self) {
222 self.past_len = 0;
223 self.kv = KvCacheState {
224 past_len: 0,
225 layers_k: vec![Vec::new(); self.n_layers],
226 layers_v: vec![Vec::new(); self.n_layers],
227 };
228 }
229
230 pub fn prefill(&mut self, embeds: ArrayView2<f32>) -> Result<Array2<f32>> {
231 let (seq, h) = embeds.dim();
232 ensure!(h == self.hidden, "cp embed hidden mismatch");
233 ensure!(
234 seq <= CP_PREFILL_SEQ,
235 "cp prefill seq {seq} > {CP_PREFILL_SEQ}"
236 );
237 let flat: Vec<f32> = embeds.iter().copied().collect();
238 let positions: Vec<usize> = (0..seq).collect();
239 let rope_table_len = self.qwen3.max_position_embeddings;
240 let (rope_cos, rope_sin) = cp_prefill_rope_feeds(
241 &self.inv_freq,
242 &positions,
243 self.qwen3.head_dim,
244 rope_table_len,
245 self.head_half,
246 );
247 let opts = talker_compile_options(&self.prefill_profile, self.compile_device);
248 let key = ((1u64) << 32) | (seq as u64);
249 let qwen3 = self.qwen3.clone();
250 let weights = Arc::clone(&self.weights);
251 let profile = self.prefill_profile.clone();
252 let built = {
253 let mut wm = weight_map_from_cache(weights.as_ref())?;
254 if seq == crate::cp_frame::CP_PREFILL_TWO {
255 build_qwen3_tts_cp_prefill_two_built(
256 &qwen3,
257 &mut wm,
258 &profile,
259 Some(rope_cos),
260 Some(rope_sin),
261 )?
262 } else {
263 crate::codec_frame::build_qwen3_tts_prefill_built(
264 &qwen3,
265 &mut wm,
266 seq,
267 &profile,
268 Some(rope_cos),
269 Some(rope_sin),
270 )?
271 }
272 };
273 let compiled = cp_compile_guard(self.session_device, self.compile_device, || {
274 compile_cache_ensure_built_with_options(&mut self.prefill_cache, key, built, &opts)
275 })?;
276 let outputs = compiled.run(&[("inputs_embeds", flat.as_slice())]);
277 let (hidden_out, kv) =
278 kv_from_prefill_outputs(outputs, 1, seq, self.kv_dim, self.n_layers)?;
279 self.kv = kv;
280 self.past_len = seq;
281 let rows = hidden_out.len() / self.hidden;
282 Ok(Array2::from_shape_vec((rows, self.hidden), hidden_out)?)
283 }
284
285 pub fn decode_step(&mut self, embed: ArrayView1<f32>) -> Result<Array1<f32>> {
286 ensure!(embed.len() == self.hidden);
287 self.decode_embed.copy_from_slice(embed.as_slice().unwrap());
288 cp_compile_guard(self.session_device, self.compile_device, || {
289 self.run_decode_step_inner()
290 })?;
291 Ok(Array1::from_vec(self.hidden_row.clone()))
292 }
293
294 fn run_decode_step_inner(&mut self) -> Result<()> {
295 let past_seq = self.past_len;
296 let pos = past_seq;
297 let (cos, sin) = rope_slice(&self.inv_freq, pos, self.qwen3.head_dim);
298 let upper = self
299 .decode_cache
300 .bucket_for(past_seq as u64)
301 .map(|idx| {
302 self.decode_cache
303 .buckets()
304 .nth(idx)
305 .map(|r| (r.end - 1) as usize)
306 .unwrap_or(past_seq)
307 })
308 .unwrap_or(past_seq);
309 bucket_decode_mask_into(past_seq, upper, &mut self.mask_buf);
310 let fixed = [
311 CacheRunInput {
312 name: "inputs_embeds",
313 data: self.decode_embed.as_slice(),
314 row_inner: None,
315 },
316 CacheRunInput {
317 name: "rope_cos",
318 data: &cos,
319 row_inner: None,
320 },
321 CacheRunInput {
322 name: "rope_sin",
323 data: &sin,
324 row_inner: None,
325 },
326 CacheRunInput {
327 name: "mask",
328 data: self.mask_buf.as_slice(),
329 row_inner: None,
330 },
331 ];
332 let opts = talker_compile_options(&self.decode_profile, self.compile_device);
333 let weights = Arc::clone(&self.weights);
334 let qwen3 = self.qwen3.clone();
335 let decode_profile = self.decode_profile.clone();
336 let (hidden_vec, new_k, new_v) = run_bucketed_kv_decode(
337 &mut self.decode_cache,
338 past_seq,
339 &self.kv,
340 self.kv_dim,
341 self.n_layers,
342 &fixed,
343 move |upper| {
344 cp_decode_graph_parts(&qwen3, weights.as_ref(), &decode_profile, upper)
345 .expect("cp decode graph")
346 },
347 &opts,
348 )?;
349 commit_kv_layers(&mut self.kv.layers_k, &mut self.kv.layers_v, &new_k, &new_v);
350 self.kv.past_len = past_seq + 1;
351 self.past_len += 1;
352 self.last_raw_hidden = hidden_vec.clone();
353 bucket_decode_hidden_into(&hidden_vec, self.hidden, &mut self.hidden_row)?;
354 Ok(())
355 }
356
357 fn prefill_stacked(&mut self, seq: usize) -> Result<()> {
358 ensure!(seq <= CP_PREFILL_SEQ);
359 let flat_len = seq * self.hidden;
360 let flat = self.prefill_scratch[..flat_len].to_vec();
361 self.run_prefill_flat(&flat, seq)
362 }
363
364 fn run_prefill_flat(&mut self, flat: &[f32], seq: usize) -> Result<()> {
365 ensure!(flat.len() == seq * self.hidden);
366 let positions: Vec<usize> = (0..seq).collect();
367 let rope_table_len = self.qwen3.max_position_embeddings;
368 let (rope_cos, rope_sin) = cp_prefill_rope_feeds(
369 &self.inv_freq,
370 &positions,
371 self.qwen3.head_dim,
372 rope_table_len,
373 self.head_half,
374 );
375 let opts = talker_compile_options(&self.prefill_profile, self.compile_device);
376 let key = ((1u64) << 32) | (seq as u64);
377 let qwen3 = self.qwen3.clone();
378 let weights = Arc::clone(&self.weights);
379 let profile = self.prefill_profile.clone();
380 let built = {
381 let mut wm = weight_map_from_cache(weights.as_ref())?;
382 if seq == crate::cp_frame::CP_PREFILL_TWO {
383 build_qwen3_tts_cp_prefill_two_built(
384 &qwen3,
385 &mut wm,
386 &profile,
387 Some(rope_cos),
388 Some(rope_sin),
389 )?
390 } else {
391 crate::codec_frame::build_qwen3_tts_prefill_built(
392 &qwen3,
393 &mut wm,
394 seq,
395 &profile,
396 Some(rope_cos),
397 Some(rope_sin),
398 )?
399 }
400 };
401 let compiled = cp_compile_guard(self.session_device, self.compile_device, || {
402 compile_cache_ensure_built_with_options(&mut self.prefill_cache, key, built, &opts)
403 })?;
404 let outputs = compiled.run(&[("inputs_embeds", flat)]);
405 let (hidden_out, kv) =
406 kv_from_prefill_outputs(outputs, 1, seq, self.kv_dim, self.n_layers)?;
407 self.kv = kv;
408 self.past_len = seq;
409 last_decode_hidden_into(&hidden_out, self.hidden, &mut self.hidden_row)?;
410 Ok(())
411 }
412
413 pub fn predict_groups(
414 &mut self,
415 talker_codec: &Array2<f32>,
416 group_embeds: &[Array2<f32>],
417 lm_heads: &[Array2<f32>],
418 talker_hidden: ArrayView1<f32>,
419 group0: u32,
420 ) -> Result<Vec<u32>> {
421 cp_compile_guard(self.session_device, self.compile_device, || {
422 self.predict_groups_inner(talker_codec, group_embeds, lm_heads, talker_hidden, group0)
423 })
424 }
425
426 fn predict_groups_inner(
427 &mut self,
428 talker_codec: &Array2<f32>,
429 group_embeds: &[Array2<f32>],
430 lm_heads: &[Array2<f32>],
431 talker_hidden: ArrayView1<f32>,
432 group0: u32,
433 ) -> Result<Vec<u32>> {
434 ensure!(talker_hidden.len() == self.hidden);
435 self.reset_kv();
436 let h = self.hidden;
437 self.prefill_scratch[..h].copy_from_slice(talker_hidden.as_slice().unwrap());
438 let e0 = talker_codec.row(group0 as usize);
439 self.prefill_scratch[h..h * 2].copy_from_slice(e0.as_slice().unwrap());
440 self.prefill_stacked(CP_PREFILL_SEQ)?;
441 let mut codes = vec![group0];
442 for step in 0..lm_heads.len() {
443 linear_logits_into(
444 ArrayView1::from(&self.hidden_row),
445 lm_heads[step].view(),
446 &mut self.logits,
447 )?;
448 let tok = sample_greedy(&self.logits);
449 codes.push(tok);
450 if step + 1 < lm_heads.len() {
451 let row = group_embeds[step].row(tok as usize);
452 self.decode_embed.copy_from_slice(row.as_slice().unwrap());
453 self.run_decode_step_inner()?;
454 }
455 }
456 Ok(codes)
457 }
458}