Skip to main content

rlx_voxtral/
runner.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Voxtral end-to-end runner — mel → audio encoder → projector → fused Llama decode.
17
18use crate::audio::{MelSpectrogram, pcm_to_mel_and_prompt};
19use crate::config::VoxtralConfig;
20use crate::embed::{argmax_token, fuse_inputs_embeds, validate_prompt_audio_match};
21use crate::encoder::build_voxtral_encoder_built;
22use crate::lm_flow::{build_voxtral_decode_built, build_voxtral_prefill_built};
23use crate::load::{VoxtralWeightStore, resolve_model_dir};
24use crate::projector::build_voxtral_projector_built;
25use crate::weights::VoxtralWeightPrefix;
26use anyhow::{Context, Result, ensure};
27use rlx_core::flow_util::compile_built;
28use rlx_core::validate_standard_device;
29use rlx_llama32::rope::{resolve_inv_freq, rope_slice};
30use rlx_runtime::Device;
31use std::path::{Path, PathBuf};
32
33#[derive(Debug, Clone, Default)]
34pub struct VoxtralRunnerBuilder {
35    weights: Option<PathBuf>,
36    config_path: Option<PathBuf>,
37    config: Option<VoxtralConfig>,
38    device: Option<Device>,
39    max_new_tokens: usize,
40}
41
42impl VoxtralRunnerBuilder {
43    pub fn weights(mut self, path: impl Into<PathBuf>) -> Self {
44        self.weights = Some(path.into());
45        self
46    }
47
48    pub fn config_path(mut self, path: impl Into<PathBuf>) -> Self {
49        self.config_path = Some(path.into());
50        self
51    }
52
53    pub fn config(mut self, cfg: VoxtralConfig) -> Self {
54        self.config = Some(cfg);
55        self
56    }
57
58    pub fn device(mut self, d: Device) -> Self {
59        self.device = Some(d);
60        self
61    }
62
63    pub fn max_new_tokens(mut self, n: usize) -> Self {
64        self.max_new_tokens = n;
65        self
66    }
67
68    pub fn build(self) -> Result<VoxtralRunner> {
69        let weights_path = self
70            .weights
71            .ok_or_else(|| anyhow::anyhow!("weights path required"))?;
72        let model_dir = resolve_model_dir(&weights_path)?;
73        let cfg_path = self
74            .config_path
75            .clone()
76            .unwrap_or_else(|| model_dir.join("config.json"));
77        let cfg = match self.config {
78            Some(c) => c,
79            None => VoxtralConfig::from_file(&cfg_path)
80                .with_context(|| format!("reading {cfg_path:?}"))?,
81        };
82        cfg.validate()?;
83        let device = self.device.unwrap_or(Device::Cpu);
84        validate_standard_device("voxtral", device)?;
85        let max_new_tokens = if self.max_new_tokens == 0 {
86            256
87        } else {
88            self.max_new_tokens
89        };
90        let weight_store = VoxtralWeightStore::open(&weights_path)?;
91
92        Ok(VoxtralRunner {
93            cfg,
94            device,
95            max_new_tokens,
96            weight_store,
97        })
98    }
99}
100
101pub struct VoxtralRunner {
102    cfg: VoxtralConfig,
103    device: Device,
104    max_new_tokens: usize,
105    weight_store: VoxtralWeightStore,
106}
107
108impl VoxtralRunner {
109    pub fn builder() -> VoxtralRunnerBuilder {
110        VoxtralRunnerBuilder::default()
111    }
112
113    pub fn config(&self) -> &VoxtralConfig {
114        &self.cfg
115    }
116
117    pub fn model_dir(&self) -> &Path {
118        self.weight_store.model_dir()
119    }
120
121    pub fn encode_audio(&self, mel: &MelSpectrogram) -> Result<Vec<f32>> {
122        let batch = 1;
123        let mel_frames = mel.n_frames;
124        let enc_seq = self.cfg.audio_config.encoder_seq_len(mel_frames);
125        ensure!(
126            enc_seq.is_multiple_of(4),
127            "encoder seq {enc_seq} not divisible by 4 — pad mel to a compatible length"
128        );
129
130        let mut wm = self.weight_store.load_audio_weights()?;
131        let enc_built =
132            build_voxtral_encoder_built(&self.cfg.audio_config, &mut wm, batch, mel_frames)?;
133        let enc_params = enc_built.params().clone();
134        let mut enc = compile_built(enc_built, self.device)?;
135        for (n, d) in &enc_params {
136            enc.set_param(n, d);
137        }
138        let enc_out = enc
139            .run(&[("mel", mel.data.as_slice())])
140            .into_iter()
141            .next()
142            .context("encoder output")?;
143        drop(wm);
144
145        let mut wm2 = self.weight_store.load_projector_weights()?;
146        let proj_built = build_voxtral_projector_built(&self.cfg, &mut wm2, batch, enc_seq)?;
147        let proj_params = proj_built.params().clone();
148        let mut proj = compile_built(proj_built, self.device)?;
149        for (n, d) in &proj_params {
150            proj.set_param(n, d);
151        }
152        let audio_embeds = proj
153            .run(&[("encoder_hidden", &enc_out)])
154            .into_iter()
155            .next()
156            .context("projector output")?;
157        Ok(audio_embeds)
158    }
159
160    pub fn generate(&self, prompt_ids: &[u32], mel: &MelSpectrogram) -> Result<Vec<u32>> {
161        let batch = 1;
162        let audio_embeds = self.encode_audio(mel)?;
163        let h = self.cfg.text_config.hidden_size;
164        let n_audio = audio_embeds.len() / h;
165        validate_prompt_audio_match(&self.cfg, prompt_ids, n_audio)?;
166
167        let embed_wm = self
168            .weight_store
169            .load_keys(&[VoxtralWeightPrefix::lm_embed_tokens()])?;
170        let inputs_embeds = fuse_inputs_embeds(&self.cfg, &embed_wm, prompt_ids, &audio_embeds)?;
171        drop(embed_wm);
172        let seq = prompt_ids.len();
173
174        let mut wm = self.weight_store.load_language_model_weights()?;
175        let prefill_built =
176            build_voxtral_prefill_built(&self.cfg, &mut wm, batch, seq, true, true)?;
177        let prefill_params = prefill_built.params().clone();
178        let mut prefill = compile_built(prefill_built, self.device)?;
179        for (n, d) in &prefill_params {
180            prefill.set_param(n, d);
181        }
182
183        let pre_in = [("inputs_embeds", inputs_embeds.as_slice())];
184        let outs = prefill.run(&pre_in);
185        let logits = &outs[0];
186        let vocab = self.cfg.text_config.vocab_size;
187        let mut tokens: Vec<u32> = prompt_ids.to_vec();
188        ensure!(
189            logits.len() == batch * vocab,
190            "expected last-token logits [{batch}, {vocab}], got {}",
191            logits.len()
192        );
193        let mut next = argmax_token(logits);
194
195        let kv_start = 1usize;
196        let mut kv_caches: Vec<Vec<f32>> = outs[kv_start..].to_vec();
197        drop(prefill);
198
199        let layers = self.cfg.text_config.num_hidden_layers;
200        let key_past: Vec<String> = (0..layers)
201            .flat_map(|i| [format!("past_k_{i}"), format!("past_v_{i}")])
202            .collect();
203
204        let llama = self.cfg.llama_config();
205        let inv_freq = resolve_inv_freq(llama, None);
206
207        for past_len in seq..seq.saturating_add(self.max_new_tokens) {
208            if next == 0 {
209                break;
210            }
211            tokens.push(next);
212
213            let mut wm_dec = self.weight_store.load_language_model_weights()?;
214            let dec_built =
215                build_voxtral_decode_built(&self.cfg, &mut wm_dec, batch, past_len, false)?;
216            let dec_params = dec_built.params().clone();
217            let mut dec = compile_built(dec_built, self.device)?;
218            for (n, d) in &dec_params {
219                dec.set_param(n, d);
220            }
221            drop(wm_dec);
222
223            let token_f = [next as f32];
224            let (cos, sin) = rope_slice(&inv_freq, past_len);
225            let mut dec_in: Vec<(&str, &[f32])> = vec![
226                ("input_ids", &token_f),
227                ("rope_cos", cos.as_slice()),
228                ("rope_sin", sin.as_slice()),
229            ];
230            for i in 0..layers {
231                dec_in.push((key_past[2 * i].as_str(), kv_caches[2 * i].as_slice()));
232                dec_in.push((
233                    key_past[2 * i + 1].as_str(),
234                    kv_caches[2 * i + 1].as_slice(),
235                ));
236            }
237            let step_out = dec.run(&dec_in);
238            next = argmax_token(&step_out[0]);
239            kv_caches = step_out[1..].to_vec();
240        }
241
242        Ok(tokens)
243    }
244
245    pub fn transcribe_wav(&self, wav: &Path, language: Option<&str>) -> Result<Vec<u32>> {
246        let (mel, prompt) = pcm_to_mel_and_prompt(self.model_dir(), Some(wav), language)?;
247        self.generate(&prompt, &mel)
248    }
249}