Skip to main content

rlx_qwen3_tts/code_predictor/
compiled.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Compiled code-predictor (5-layer Qwen3, `inputs_embeds` + KV decode).
17
18use crate::codec_frame::{Qwen3TtsGraphProfiles, Qwen3TtsGraphRole, cp_decode_graph_parts};
19use crate::compile_opts::{cp_compile_device, metal_compile_guard, talker_compile_options};
20use crate::config::CodePredictorConfig;
21use crate::cp_frame::build_qwen3_tts_cp_prefill_two_built;
22use crate::kv_util::commit_kv_layers;
23use crate::load::{Qwen3TtsWeightStore, remap_code_predictor_weights};
24use crate::talker::math::{
25    bucket_decode_hidden_into, last_decode_hidden_into, linear_logits_into, sample_greedy,
26};
27use crate::talker::rope::{rope_prefill_feeds, rope_slice, rope_tables_full};
28use crate::weights::weight_map_from_cache;
29use anyhow::{Result, ensure};
30use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
31use rlx_core::autoregressive::{KvCacheState, kv_from_prefill_outputs, run_bucketed_kv_decode};
32use rlx_core::flow_util::compile_cache_ensure_built_with_options;
33use rlx_flow::CompileProfile;
34use rlx_runtime::Device;
35use rlx_runtime::compile_cache::{BucketedCompileCache, CacheRunInput, CompileCache};
36use std::path::Path;
37use std::sync::Arc;
38
/// Largest prefill sequence length the engine accepts; also sizes the
/// prefill staging buffer (`hidden * CP_PREFILL_SEQ` floats).
const CP_PREFILL_SEQ: usize = 2;
/// Upper bound handed to the power-of-two decode bucket ladder and used to
/// filter which buckets are precompiled.
const CP_DECODE_BUCKET_MAX: u64 = 32;
/// Match eager CP rope tables (`CpEagerModel` caps at 4096); HF lists 65536 but AR depth ≤ 32.
const CP_ROPE_TABLE_LEN: usize = 4096;
43
/// Compiled code-predictor engine: owns the weights, the prefill/decode
/// compile caches, the KV cache, and the scratch buffers reused across steps.
pub struct CpCompiledEngine {
    /// Model config derived from the code-predictor config, with
    /// `max_position_embeddings` capped at `CP_ROPE_TABLE_LEN`.
    qwen3: rlx_qwen3::Qwen3Config,
    /// Session device (Metal/CPU/CUDA); compile caches may run on CPU when Metal is session.
    session_device: Device,
    /// Device the compile caches and compile options are created for.
    compile_device: Device,
    /// Hidden size (`cp.hidden_size`); length of one embedding / hidden row.
    hidden: usize,
    /// KV projection width reported by `qwen3.kv_proj_dim()`.
    kv_dim: usize,
    /// Number of transformer layers (`cp.num_hidden_layers`).
    n_layers: usize,
    /// Half of the head dimension (`cp.head_dim / 2`); row width of the rope tables.
    head_half: usize,
    /// Rope inverse frequencies built from `cp.head_dim` and `cp.rope_theta`.
    inv_freq: Vec<f64>,
    /// Remapped code-predictor weights, shared with the graph builders.
    weights: Arc<crate::load::TensorSnapshot>,
    /// Compile profile used when building prefill graphs.
    prefill_profile: CompileProfile,
    /// Compile profile used when building decode graphs.
    decode_profile: CompileProfile,
    /// Number of positions currently held in the KV cache (set by prefill,
    /// incremented by each decode step).
    past_len: usize,
    /// Per-layer key/value cache.
    kv: KvCacheState,
    /// Cache of compiled prefill graphs, keyed by sequence length.
    prefill_cache: CompileCache,
    /// Bucketed cache of compiled decode graphs.
    decode_cache: BucketedCompileCache,
    /// Staging buffer for the stacked prefill rows (`hidden * CP_PREFILL_SEQ` floats).
    prefill_scratch: Vec<f32>,
    /// Staging buffer for the single decode-step embedding (`hidden` floats).
    decode_embed: Vec<f32>,
    /// Hidden row produced by the most recent prefill or decode step.
    hidden_row: Vec<f32>,
    /// Raw hidden output of the most recent decode step, before row extraction.
    last_raw_hidden: Vec<f32>,
    /// Logits scratch buffer (`cp.vocab_size` floats).
    logits: Vec<f32>,
    /// Reusable attention-mask buffer for bucketed decode.
    mask_buf: Vec<f32>,
}
68
69/// Full `[max_pos, head_half]` rope tables with active prefill positions at the front (talker pattern).
70fn cp_prefill_rope_feeds(
71    inv_freq: &[f64],
72    positions: &[usize],
73    head_dim: usize,
74    rope_table_len: usize,
75    head_half: usize,
76) -> (Vec<f32>, Vec<f32>) {
77    let (mut cos, mut sin) = rope_tables_full(inv_freq, rope_table_len, head_dim);
78    let (seq_cos, seq_sin) = rope_prefill_feeds(inv_freq, positions, head_dim);
79    for t in 0..positions.len() {
80        let off = t * head_half;
81        cos[off..off + head_half].copy_from_slice(&seq_cos[off..off + head_half]);
82        sin[off..off + head_half].copy_from_slice(&seq_sin[off..off + head_half]);
83    }
84    (cos, sin)
85}
86
87fn cp_compile_guard<R, F>(session_device: Device, compile_device: Device, f: F) -> R
88where
89    F: FnOnce() -> R,
90{
91    if compile_device == Device::Cpu {
92        f()
93    } else {
94        metal_compile_guard(session_device, f)
95    }
96}
97
/// Rewrites `out` as a decode mask of length `upper + 1`.
///
/// Entries at indices below `past_seq`, and the entry at index `upper`, are
/// `1.0`; every other entry is `0.0`. Previous contents of `out` are discarded.
fn bucket_decode_mask_into(past_seq: usize, upper: usize, out: &mut Vec<f32>) {
    let len = upper + 1;
    out.clear();
    out.extend((0..len).map(|i| if i < past_seq || i == upper { 1.0 } else { 0.0 }));
}
104
105impl CpCompiledEngine {
106    pub fn open(
107        model_dir: &Path,
108        store: &Qwen3TtsWeightStore,
109        cp: &CodePredictorConfig,
110        device: Device,
111    ) -> Result<Self> {
112        let mut wm = store.load_code_predictor_backbone()?;
113        let weights = remap_code_predictor_weights(&mut wm)?;
114        let compile_device = cp_compile_device(device);
115        let profiles = Qwen3TtsGraphProfiles::for_role(
116            model_dir,
117            Qwen3TtsGraphRole::CodePredictor,
118            compile_device,
119        );
120        let prefill = profiles.prefill;
121        let decode = profiles.decode;
122        let mut qwen3 = cp.to_qwen3_config();
123        qwen3.max_position_embeddings = qwen3.max_position_embeddings.min(CP_ROPE_TABLE_LEN);
124        let hidden = cp.hidden_size;
125        let kv_dim = qwen3.kv_proj_dim();
126        let n_layers = cp.num_hidden_layers;
127        let head_half = cp.head_dim / 2;
128        let inv_freq = crate::talker::rope::build_inv_freq(cp.head_dim, cp.rope_theta);
129        Ok(Self {
130            qwen3,
131            session_device: device,
132            compile_device,
133            hidden,
134            kv_dim,
135            n_layers,
136            head_half,
137            inv_freq,
138            weights: Arc::new(weights),
139            prefill_profile: prefill,
140            decode_profile: decode,
141            past_len: 0,
142            kv: KvCacheState {
143                past_len: 0,
144                layers_k: vec![Vec::new(); n_layers],
145                layers_v: vec![Vec::new(); n_layers],
146            },
147            prefill_cache: CompileCache::new(compile_device, 4),
148            decode_cache: BucketedCompileCache::power_of_two_ladder(
149                compile_device,
150                1,
151                CP_DECODE_BUCKET_MAX,
152            ),
153            prefill_scratch: vec![0f32; hidden * CP_PREFILL_SEQ],
154            decode_embed: vec![0f32; hidden],
155            hidden_row: vec![0f32; hidden],
156            last_raw_hidden: Vec::new(),
157            logits: vec![0f32; cp.vocab_size],
158            mask_buf: Vec::new(),
159        })
160    }
161
162    #[doc(hidden)]
163    pub fn last_raw_hidden(&self) -> &[f32] {
164        &self.last_raw_hidden
165    }
166
167    #[doc(hidden)]
168    pub fn export_kv_state(&self) -> (KvCacheState, usize) {
169        (self.kv.clone(), self.past_len)
170    }
171
172    #[doc(hidden)]
173    pub fn import_kv_state(&mut self, kv: KvCacheState, past_len: usize) {
174        self.kv = kv;
175        self.past_len = past_len;
176    }
177
178    pub fn warmup(&mut self, max_frames: usize) -> Result<()> {
179        let mut embeds = Array2::<f32>::zeros((CP_PREFILL_SEQ, self.hidden));
180        embeds[[0, 0]] = 1e-4;
181        self.reset_kv();
182        self.prefill(embeds.view())?;
183        if crate::synth_opts::lazy_talk_buckets()
184            && !crate::synth_opts::auto_precompile_horizon(max_frames)
185        {
186            let emb = vec![0f32; self.hidden];
187            let _ = self.decode_step(ArrayView1::from(&emb))?;
188        } else {
189            self.precompile_decode_buckets()?;
190        }
191        Ok(())
192    }
193
194    /// Warm CP decode buckets (past ≤ 16 per frame; ladder tops at 32).
195    fn precompile_decode_buckets(&mut self) -> Result<()> {
196        let keys: Vec<u64> = self
197            .decode_cache
198            .buckets()
199            .map(|r| r.end.saturating_sub(1))
200            .filter(|&k| k <= CP_DECODE_BUCKET_MAX)
201            .collect();
202        let opts = talker_compile_options(&self.decode_profile, self.compile_device);
203        for &key in &keys {
204            let weights = Arc::clone(&self.weights);
205            let qwen3 = self.qwen3.clone();
206            let decode_profile = self.decode_profile.clone();
207            cp_compile_guard(self.session_device, self.compile_device, || {
208                let _ = self.decode_cache.ensure_graph_with_params(
209                    key,
210                    move |upper| {
211                        cp_decode_graph_parts(&qwen3, weights.as_ref(), &decode_profile, upper)
212                            .expect("cp decode graph")
213                    },
214                    &opts,
215                );
216            });
217        }
218        Ok(())
219    }
220
221    pub fn reset_kv(&mut self) {
222        self.past_len = 0;
223        self.kv = KvCacheState {
224            past_len: 0,
225            layers_k: vec![Vec::new(); self.n_layers],
226            layers_v: vec![Vec::new(); self.n_layers],
227        };
228    }
229
230    pub fn prefill(&mut self, embeds: ArrayView2<f32>) -> Result<Array2<f32>> {
231        let (seq, h) = embeds.dim();
232        ensure!(h == self.hidden, "cp embed hidden mismatch");
233        ensure!(
234            seq <= CP_PREFILL_SEQ,
235            "cp prefill seq {seq} > {CP_PREFILL_SEQ}"
236        );
237        let flat: Vec<f32> = embeds.iter().copied().collect();
238        let positions: Vec<usize> = (0..seq).collect();
239        let rope_table_len = self.qwen3.max_position_embeddings;
240        let (rope_cos, rope_sin) = cp_prefill_rope_feeds(
241            &self.inv_freq,
242            &positions,
243            self.qwen3.head_dim,
244            rope_table_len,
245            self.head_half,
246        );
247        let opts = talker_compile_options(&self.prefill_profile, self.compile_device);
248        let key = ((1u64) << 32) | (seq as u64);
249        let qwen3 = self.qwen3.clone();
250        let weights = Arc::clone(&self.weights);
251        let profile = self.prefill_profile.clone();
252        let built = {
253            let mut wm = weight_map_from_cache(weights.as_ref())?;
254            if seq == crate::cp_frame::CP_PREFILL_TWO {
255                build_qwen3_tts_cp_prefill_two_built(
256                    &qwen3,
257                    &mut wm,
258                    &profile,
259                    Some(rope_cos),
260                    Some(rope_sin),
261                )?
262            } else {
263                crate::codec_frame::build_qwen3_tts_prefill_built(
264                    &qwen3,
265                    &mut wm,
266                    seq,
267                    &profile,
268                    Some(rope_cos),
269                    Some(rope_sin),
270                )?
271            }
272        };
273        let compiled = cp_compile_guard(self.session_device, self.compile_device, || {
274            compile_cache_ensure_built_with_options(&mut self.prefill_cache, key, built, &opts)
275        })?;
276        let outputs = compiled.run(&[("inputs_embeds", flat.as_slice())]);
277        let (hidden_out, kv) =
278            kv_from_prefill_outputs(outputs, 1, seq, self.kv_dim, self.n_layers)?;
279        self.kv = kv;
280        self.past_len = seq;
281        let rows = hidden_out.len() / self.hidden;
282        Ok(Array2::from_shape_vec((rows, self.hidden), hidden_out)?)
283    }
284
285    pub fn decode_step(&mut self, embed: ArrayView1<f32>) -> Result<Array1<f32>> {
286        ensure!(embed.len() == self.hidden);
287        self.decode_embed.copy_from_slice(embed.as_slice().unwrap());
288        cp_compile_guard(self.session_device, self.compile_device, || {
289            self.run_decode_step_inner()
290        })?;
291        Ok(Array1::from_vec(self.hidden_row.clone()))
292    }
293
294    fn run_decode_step_inner(&mut self) -> Result<()> {
295        let past_seq = self.past_len;
296        let pos = past_seq;
297        let (cos, sin) = rope_slice(&self.inv_freq, pos, self.qwen3.head_dim);
298        let upper = self
299            .decode_cache
300            .bucket_for(past_seq as u64)
301            .map(|idx| {
302                self.decode_cache
303                    .buckets()
304                    .nth(idx)
305                    .map(|r| (r.end - 1) as usize)
306                    .unwrap_or(past_seq)
307            })
308            .unwrap_or(past_seq);
309        bucket_decode_mask_into(past_seq, upper, &mut self.mask_buf);
310        let fixed = [
311            CacheRunInput {
312                name: "inputs_embeds",
313                data: self.decode_embed.as_slice(),
314                row_inner: None,
315            },
316            CacheRunInput {
317                name: "rope_cos",
318                data: &cos,
319                row_inner: None,
320            },
321            CacheRunInput {
322                name: "rope_sin",
323                data: &sin,
324                row_inner: None,
325            },
326            CacheRunInput {
327                name: "mask",
328                data: self.mask_buf.as_slice(),
329                row_inner: None,
330            },
331        ];
332        let opts = talker_compile_options(&self.decode_profile, self.compile_device);
333        let weights = Arc::clone(&self.weights);
334        let qwen3 = self.qwen3.clone();
335        let decode_profile = self.decode_profile.clone();
336        let (hidden_vec, new_k, new_v) = run_bucketed_kv_decode(
337            &mut self.decode_cache,
338            past_seq,
339            &self.kv,
340            self.kv_dim,
341            self.n_layers,
342            &fixed,
343            move |upper| {
344                cp_decode_graph_parts(&qwen3, weights.as_ref(), &decode_profile, upper)
345                    .expect("cp decode graph")
346            },
347            &opts,
348        )?;
349        commit_kv_layers(&mut self.kv.layers_k, &mut self.kv.layers_v, &new_k, &new_v);
350        self.kv.past_len = past_seq + 1;
351        self.past_len += 1;
352        self.last_raw_hidden = hidden_vec.clone();
353        bucket_decode_hidden_into(&hidden_vec, self.hidden, &mut self.hidden_row)?;
354        Ok(())
355    }
356
357    fn prefill_stacked(&mut self, seq: usize) -> Result<()> {
358        ensure!(seq <= CP_PREFILL_SEQ);
359        let flat_len = seq * self.hidden;
360        let flat = self.prefill_scratch[..flat_len].to_vec();
361        self.run_prefill_flat(&flat, seq)
362    }
363
364    fn run_prefill_flat(&mut self, flat: &[f32], seq: usize) -> Result<()> {
365        ensure!(flat.len() == seq * self.hidden);
366        let positions: Vec<usize> = (0..seq).collect();
367        let rope_table_len = self.qwen3.max_position_embeddings;
368        let (rope_cos, rope_sin) = cp_prefill_rope_feeds(
369            &self.inv_freq,
370            &positions,
371            self.qwen3.head_dim,
372            rope_table_len,
373            self.head_half,
374        );
375        let opts = talker_compile_options(&self.prefill_profile, self.compile_device);
376        let key = ((1u64) << 32) | (seq as u64);
377        let qwen3 = self.qwen3.clone();
378        let weights = Arc::clone(&self.weights);
379        let profile = self.prefill_profile.clone();
380        let built = {
381            let mut wm = weight_map_from_cache(weights.as_ref())?;
382            if seq == crate::cp_frame::CP_PREFILL_TWO {
383                build_qwen3_tts_cp_prefill_two_built(
384                    &qwen3,
385                    &mut wm,
386                    &profile,
387                    Some(rope_cos),
388                    Some(rope_sin),
389                )?
390            } else {
391                crate::codec_frame::build_qwen3_tts_prefill_built(
392                    &qwen3,
393                    &mut wm,
394                    seq,
395                    &profile,
396                    Some(rope_cos),
397                    Some(rope_sin),
398                )?
399            }
400        };
401        let compiled = cp_compile_guard(self.session_device, self.compile_device, || {
402            compile_cache_ensure_built_with_options(&mut self.prefill_cache, key, built, &opts)
403        })?;
404        let outputs = compiled.run(&[("inputs_embeds", flat)]);
405        let (hidden_out, kv) =
406            kv_from_prefill_outputs(outputs, 1, seq, self.kv_dim, self.n_layers)?;
407        self.kv = kv;
408        self.past_len = seq;
409        last_decode_hidden_into(&hidden_out, self.hidden, &mut self.hidden_row)?;
410        Ok(())
411    }
412
413    pub fn predict_groups(
414        &mut self,
415        talker_codec: &Array2<f32>,
416        group_embeds: &[Array2<f32>],
417        lm_heads: &[Array2<f32>],
418        talker_hidden: ArrayView1<f32>,
419        group0: u32,
420    ) -> Result<Vec<u32>> {
421        cp_compile_guard(self.session_device, self.compile_device, || {
422            self.predict_groups_inner(talker_codec, group_embeds, lm_heads, talker_hidden, group0)
423        })
424    }
425
426    fn predict_groups_inner(
427        &mut self,
428        talker_codec: &Array2<f32>,
429        group_embeds: &[Array2<f32>],
430        lm_heads: &[Array2<f32>],
431        talker_hidden: ArrayView1<f32>,
432        group0: u32,
433    ) -> Result<Vec<u32>> {
434        ensure!(talker_hidden.len() == self.hidden);
435        self.reset_kv();
436        let h = self.hidden;
437        self.prefill_scratch[..h].copy_from_slice(talker_hidden.as_slice().unwrap());
438        let e0 = talker_codec.row(group0 as usize);
439        self.prefill_scratch[h..h * 2].copy_from_slice(e0.as_slice().unwrap());
440        self.prefill_stacked(CP_PREFILL_SEQ)?;
441        let mut codes = vec![group0];
442        for step in 0..lm_heads.len() {
443            linear_logits_into(
444                ArrayView1::from(&self.hidden_row),
445                lm_heads[step].view(),
446                &mut self.logits,
447            )?;
448            let tok = sample_greedy(&self.logits);
449            codes.push(tok);
450            if step + 1 < lm_heads.len() {
451                let row = group_embeds[step].row(tok as usize);
452                self.decode_embed.copy_from_slice(row.as_slice().unwrap());
453                self.run_decode_step_inner()?;
454            }
455        }
456        Ok(codes)
457    }
458}