// rlx_qwen3_tts/code_predictor/engine.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! Code-predictor AR (compiled GPU/MLX/Metal when available, else CPU eager).

18use crate::code_predictor::compiled::CpCompiledEngine;
19use crate::code_predictor::eager::CpEagerModel;
20use crate::config::CodePredictorConfig;
21use crate::load::Qwen3TtsWeightStore;
22use anyhow::{Context, Result, ensure};
23use ndarray::{Array2, ArrayView1};
24use rlx_runtime::Device;
25
/// Returns `true` only when the `RLX_QWEN3_TTS_CP_EAGER` environment variable
/// is set to exactly `"1"`; unset, non-UTF-8, or any other value yields `false`.
fn cp_force_eager() -> bool {
    matches!(std::env::var("RLX_QWEN3_TTS_CP_EAGER"), Ok(value) if value == "1")
}
29
30/// GPU sessions use compiled CP on Metal/CUDA/ROCm. CPU eager: `RLX_QWEN3_TTS_CP_EAGER=1`.
31pub fn cp_use_compiled_for_device(talker_device: Device) -> bool {
32    if cp_force_eager() {
33        return false;
34    }
35    if std::env::var("RLX_QWEN3_TTS_CP_COMPILED").ok().as_deref() == Some("1") {
36        return true;
37    }
38    if crate::gpu_pipeline::gpu_session_enabled(talker_device) {
39        return crate::gpu_pipeline::cp_use_gpu_on_device(talker_device);
40    }
41    talker_device != Device::Cpu && talker_device != Device::Metal
42}
43
44fn cp_execution_device(talker_device: Device) -> Device {
45    if !cp_use_compiled_for_device(talker_device) {
46        Device::Cpu
47    } else {
48        crate::compile_opts::cp_compile_device(talker_device)
49    }
50}
51
52enum CpBackend {
53    Eager(CpEagerModel),
54    Compiled(CpCompiledEngine),
55}
56
57pub struct CodePredictorEngine {
58    talker_device: Device,
59    cp_device: Device,
60    backend: CpBackend,
61    talker_codec: Array2<f32>,
62    talker_codec_flat: Vec<f32>,
63    group_embeds: Vec<Array2<f32>>,
64    group_embed_flat: Vec<Vec<f32>>,
65    lm_heads: Vec<Array2<f32>>,
66    lm_head_flat: Vec<Vec<f32>>,
67    lm_head_vocab: Vec<usize>,
68    hidden: usize,
69}
70
71impl CodePredictorEngine {
72    pub fn open(
73        store: &Qwen3TtsWeightStore,
74        cp: &CodePredictorConfig,
75        device: Device,
76    ) -> Result<Self> {
77        let talker_snap = store.tensor_snapshot(&["talker.model.codec_embedding.weight"])?;
78        let (tc_data, tc_shape) = talker_snap
79            .get("talker.model.codec_embedding.weight")
80            .context("talker codec_embedding")?;
81        let talker_codec_flat = tc_data.clone();
82        let talker_codec =
83            Array2::from_shape_vec((tc_shape[0], tc_shape[1]), talker_codec_flat.clone())?;
84
85        let mut group_embeds = Vec::with_capacity(cp.num_code_groups - 1);
86        let mut group_embed_flat = Vec::with_capacity(cp.num_code_groups - 1);
87        for i in 0..cp.num_code_groups - 1 {
88            let key = format!("talker.code_predictor.model.codec_embedding.{i}.weight");
89            let (data, shape) = store.tensor_snapshot(&[&key])?[&key].clone();
90            group_embeds.push(Array2::from_shape_vec((shape[0], shape[1]), data.clone())?);
91            group_embed_flat.push(data);
92        }
93        let mut lm_heads = Vec::with_capacity(cp.num_code_groups - 1);
94        let mut lm_head_flat = Vec::with_capacity(cp.num_code_groups - 1);
95        let mut lm_head_vocab = Vec::with_capacity(cp.num_code_groups - 1);
96        for i in 0..cp.num_code_groups - 1 {
97            let key = format!("talker.code_predictor.lm_head.{i}.weight");
98            let (data, shape) = store.tensor_snapshot(&[&key])?[&key].clone();
99            lm_head_vocab.push(shape[0]);
100            lm_head_flat.push(data.clone());
101            lm_heads.push(Array2::from_shape_vec((shape[0], shape[1]), data)?);
102        }
103
104        let cp_device = cp_execution_device(device);
105        let backend = if cp_use_compiled_for_device(device) {
106            CpBackend::Compiled(CpCompiledEngine::open(
107                store.model_dir(),
108                store,
109                cp,
110                cp_device,
111            )?)
112        } else {
113            CpBackend::Eager(CpEagerModel::open(store, cp)?)
114        };
115
116        Ok(Self {
117            talker_device: device,
118            cp_device,
119            backend,
120            talker_codec,
121            talker_codec_flat,
122            group_embeds,
123            group_embed_flat,
124            lm_heads,
125            lm_head_flat,
126            lm_head_vocab,
127            hidden: cp.hidden_size,
128        })
129    }
130
131    pub fn is_eager(&self) -> bool {
132        matches!(self.backend, CpBackend::Eager(_))
133    }
134
135    /// Talker codec embedding table flat (row-major `[codec_vocab, hidden]`).
136    /// Used by the speculative path to do cheap g0 group-embedding swaps when
137    /// synthesising verifier inputs from drafted g0 proposals.
138    pub fn talker_codec_flat(&self) -> (&[f32], usize) {
139        (&self.talker_codec_flat, self.hidden)
140    }
141
142    pub fn device(&self) -> Device {
143        self.cp_device
144    }
145
146    pub fn cp_backend_label(&self) -> String {
147        match &self.backend {
148            CpBackend::Eager(_) => "CPU eager".into(),
149            CpBackend::Compiled(_) if self.cp_device != self.talker_device => {
150                format!("compiled (CPU, talker {:?})", self.talker_device)
151            }
152            CpBackend::Compiled(_) => format!("compiled ({:?})", self.cp_device),
153        }
154    }
155
156    pub fn warmup(&mut self, max_frames: usize) -> Result<()> {
157        match &mut self.backend {
158            CpBackend::Eager(e) => {
159                let mut hidden = vec![0f32; self.hidden];
160                for (i, v) in hidden.iter_mut().enumerate() {
161                    *v = ((i % 17) as f32) * 1e-5;
162                }
163                let _ = e.predict_groups(
164                    &self.talker_codec,
165                    &self.group_embeds,
166                    &self.lm_heads,
167                    ArrayView1::from(&hidden),
168                    1995,
169                )?;
170                Ok(())
171            }
172            CpBackend::Compiled(c) => c.warmup(max_frames),
173        }
174    }
175
176    pub fn predict_groups_slice(&mut self, talker_hidden: &[f32], group0: u32) -> Result<Vec<u32>> {
177        self.predict_groups(ArrayView1::from(talker_hidden), group0)
178    }
179
180    /// CP predict + codec embed sum + pad (fused on eager).
181    pub fn predict_groups_fill_emb(
182        &mut self,
183        talker_hidden: &[f32],
184        group0: u32,
185        pad: &[f32],
186        codec_emb: &mut [f32],
187    ) -> Result<Vec<u32>> {
188        ensure!(codec_emb.len() == self.hidden);
189        match &mut self.backend {
190            CpBackend::Eager(e) => e.predict_groups_fill_emb_flat(
191                &self.talker_codec_flat,
192                &self.group_embed_flat,
193                &self.lm_head_flat,
194                &self.lm_head_vocab,
195                ArrayView1::from(talker_hidden),
196                group0,
197                pad,
198                codec_emb,
199                self.hidden,
200            ),
201            CpBackend::Compiled(c) => {
202                let groups = c.predict_groups(
203                    &self.talker_codec,
204                    &self.group_embeds,
205                    &self.lm_heads,
206                    ArrayView1::from(talker_hidden),
207                    group0,
208                )?;
209                codec_emb.fill(0.0);
210                self.sum_codec_groups_into(&groups, codec_emb)?;
211                for (j, v) in pad.iter().enumerate() {
212                    codec_emb[j] += *v;
213                }
214                Ok(groups)
215            }
216        }
217    }
218
219    pub fn predict_groups(
220        &mut self,
221        talker_hidden: ArrayView1<f32>,
222        group0: u32,
223    ) -> Result<Vec<u32>> {
224        ensure!(talker_hidden.len() == self.hidden);
225        match &mut self.backend {
226            CpBackend::Eager(e) => e.predict_groups(
227                &self.talker_codec,
228                &self.group_embeds,
229                &self.lm_heads,
230                talker_hidden,
231                group0,
232            ),
233            CpBackend::Compiled(c) => c.predict_groups(
234                &self.talker_codec,
235                &self.group_embeds,
236                &self.lm_heads,
237                talker_hidden,
238                group0,
239            ),
240        }
241    }
242
243    /// Sum codec group embeddings into `out` (group 0 = talker table).
244    pub fn sum_codec_groups_into(&self, groups: &[u32], out: &mut [f32]) -> Result<()> {
245        ensure!(out.len() == self.hidden, "codec emb buffer len mismatch");
246        out.fill(0.0);
247        for (gi, &tok) in groups.iter().enumerate() {
248            if gi == 0 {
249                ensure!(
250                    (tok as usize) < self.talker_codec.nrows(),
251                    "group0 token {tok} oob"
252                );
253                for (j, v) in self.talker_codec.row(tok as usize).iter().enumerate() {
254                    out[j] += *v;
255                }
256            } else {
257                let table = &self.group_embeds[gi - 1];
258                ensure!(
259                    (tok as usize) < table.nrows(),
260                    "token {tok} oob for group {gi}"
261                );
262                for (j, v) in table.row(tok as usize).iter().enumerate() {
263                    out[j] += *v;
264                }
265            }
266        }
267        Ok(())
268    }
269
270    pub fn sum_codec_groups(&self, groups: &[u32]) -> Result<Vec<f32>> {
271        let mut emb = vec![0f32; self.hidden];
272        self.sum_codec_groups_into(groups, &mut emb)?;
273        Ok(emb)
274    }
275
276    /// Per-step codec embeds for fused backbone (`groups[1..]` → `cp_step_embed_{i}`).
277    pub fn cp_step_embeds_from_groups(&self, groups: &[u32]) -> Result<Vec<Vec<f32>>> {
278        use crate::cp_frame::CP_DECODE_BACKBONE_STEPS;
279        ensure!(
280            groups.len() > CP_DECODE_BACKBONE_STEPS,
281            "groups len {} < {}",
282            groups.len(),
283            1 + CP_DECODE_BACKBONE_STEPS
284        );
285        let mut out = Vec::with_capacity(CP_DECODE_BACKBONE_STEPS);
286        for step in 0..CP_DECODE_BACKBONE_STEPS {
287            out.push(self.codec_embed_row(step + 1, groups[step + 1])?);
288        }
289        Ok(out)
290    }
291
292    pub fn codec_embed_row(&self, group_idx: usize, token: u32) -> Result<Vec<f32>> {
293        if group_idx == 0 {
294            ensure!(
295                (token as usize) < self.talker_codec.nrows(),
296                "group0 token {token} oob"
297            );
298            return Ok(self.talker_codec.row(token as usize).to_vec());
299        }
300        let gi = group_idx - 1;
301        ensure!(gi < self.group_embeds.len(), "group_idx {group_idx} oob");
302        let table = &self.group_embeds[gi];
303        ensure!(
304            (token as usize) < table.nrows(),
305            "token {token} oob for group {group_idx}"
306        );
307        Ok(table.row(token as usize).to_vec())
308    }
309}