1use crate::audio::{MelSpectrogram, pcm_to_mel_and_prompt};
19use crate::config::VoxtralConfig;
20use crate::embed::{argmax_token, fuse_inputs_embeds, validate_prompt_audio_match};
21use crate::encoder::build_voxtral_encoder_built;
22use crate::lm_flow::{build_voxtral_decode_built, build_voxtral_prefill_built};
23use crate::load::{VoxtralWeightStore, resolve_model_dir};
24use crate::projector::build_voxtral_projector_built;
25use crate::weights::VoxtralWeightPrefix;
26use anyhow::{Context, Result, ensure};
27use rlx_core::flow_util::compile_built;
28use rlx_core::validate_standard_device;
29use rlx_llama32::rope::{resolve_inv_freq, rope_slice};
30use rlx_runtime::Device;
31use std::path::{Path, PathBuf};
32
33#[derive(Debug, Clone, Default)]
34pub struct VoxtralRunnerBuilder {
35 weights: Option<PathBuf>,
36 config_path: Option<PathBuf>,
37 config: Option<VoxtralConfig>,
38 device: Option<Device>,
39 max_new_tokens: usize,
40}
41
42impl VoxtralRunnerBuilder {
43 pub fn weights(mut self, path: impl Into<PathBuf>) -> Self {
44 self.weights = Some(path.into());
45 self
46 }
47
48 pub fn config_path(mut self, path: impl Into<PathBuf>) -> Self {
49 self.config_path = Some(path.into());
50 self
51 }
52
53 pub fn config(mut self, cfg: VoxtralConfig) -> Self {
54 self.config = Some(cfg);
55 self
56 }
57
58 pub fn device(mut self, d: Device) -> Self {
59 self.device = Some(d);
60 self
61 }
62
63 pub fn max_new_tokens(mut self, n: usize) -> Self {
64 self.max_new_tokens = n;
65 self
66 }
67
68 pub fn build(self) -> Result<VoxtralRunner> {
69 let weights_path = self
70 .weights
71 .ok_or_else(|| anyhow::anyhow!("weights path required"))?;
72 let model_dir = resolve_model_dir(&weights_path)?;
73 let cfg_path = self
74 .config_path
75 .clone()
76 .unwrap_or_else(|| model_dir.join("config.json"));
77 let cfg = match self.config {
78 Some(c) => c,
79 None => VoxtralConfig::from_file(&cfg_path)
80 .with_context(|| format!("reading {cfg_path:?}"))?,
81 };
82 cfg.validate()?;
83 let device = self.device.unwrap_or(Device::Cpu);
84 validate_standard_device("voxtral", device)?;
85 let max_new_tokens = if self.max_new_tokens == 0 {
86 256
87 } else {
88 self.max_new_tokens
89 };
90 let weight_store = VoxtralWeightStore::open(&weights_path)?;
91
92 Ok(VoxtralRunner {
93 cfg,
94 device,
95 max_new_tokens,
96 weight_store,
97 })
98 }
99}
100
101pub struct VoxtralRunner {
102 cfg: VoxtralConfig,
103 device: Device,
104 max_new_tokens: usize,
105 weight_store: VoxtralWeightStore,
106}
107
108impl VoxtralRunner {
109 pub fn builder() -> VoxtralRunnerBuilder {
110 VoxtralRunnerBuilder::default()
111 }
112
113 pub fn config(&self) -> &VoxtralConfig {
114 &self.cfg
115 }
116
117 pub fn model_dir(&self) -> &Path {
118 self.weight_store.model_dir()
119 }
120
121 pub fn encode_audio(&self, mel: &MelSpectrogram) -> Result<Vec<f32>> {
122 let batch = 1;
123 let mel_frames = mel.n_frames;
124 let enc_seq = self.cfg.audio_config.encoder_seq_len(mel_frames);
125 ensure!(
126 enc_seq.is_multiple_of(4),
127 "encoder seq {enc_seq} not divisible by 4 — pad mel to a compatible length"
128 );
129
130 let mut wm = self.weight_store.load_audio_weights()?;
131 let enc_built =
132 build_voxtral_encoder_built(&self.cfg.audio_config, &mut wm, batch, mel_frames)?;
133 let enc_params = enc_built.params().clone();
134 let mut enc = compile_built(enc_built, self.device)?;
135 for (n, d) in &enc_params {
136 enc.set_param(n, d);
137 }
138 let enc_out = enc
139 .run(&[("mel", mel.data.as_slice())])
140 .into_iter()
141 .next()
142 .context("encoder output")?;
143 drop(wm);
144
145 let mut wm2 = self.weight_store.load_projector_weights()?;
146 let proj_built = build_voxtral_projector_built(&self.cfg, &mut wm2, batch, enc_seq)?;
147 let proj_params = proj_built.params().clone();
148 let mut proj = compile_built(proj_built, self.device)?;
149 for (n, d) in &proj_params {
150 proj.set_param(n, d);
151 }
152 let audio_embeds = proj
153 .run(&[("encoder_hidden", &enc_out)])
154 .into_iter()
155 .next()
156 .context("projector output")?;
157 Ok(audio_embeds)
158 }
159
160 pub fn generate(&self, prompt_ids: &[u32], mel: &MelSpectrogram) -> Result<Vec<u32>> {
161 let batch = 1;
162 let audio_embeds = self.encode_audio(mel)?;
163 let h = self.cfg.text_config.hidden_size;
164 let n_audio = audio_embeds.len() / h;
165 validate_prompt_audio_match(&self.cfg, prompt_ids, n_audio)?;
166
167 let embed_wm = self
168 .weight_store
169 .load_keys(&[VoxtralWeightPrefix::lm_embed_tokens()])?;
170 let inputs_embeds = fuse_inputs_embeds(&self.cfg, &embed_wm, prompt_ids, &audio_embeds)?;
171 drop(embed_wm);
172 let seq = prompt_ids.len();
173
174 let mut wm = self.weight_store.load_language_model_weights()?;
175 let prefill_built =
176 build_voxtral_prefill_built(&self.cfg, &mut wm, batch, seq, true, true)?;
177 let prefill_params = prefill_built.params().clone();
178 let mut prefill = compile_built(prefill_built, self.device)?;
179 for (n, d) in &prefill_params {
180 prefill.set_param(n, d);
181 }
182
183 let pre_in = [("inputs_embeds", inputs_embeds.as_slice())];
184 let outs = prefill.run(&pre_in);
185 let logits = &outs[0];
186 let vocab = self.cfg.text_config.vocab_size;
187 let mut tokens: Vec<u32> = prompt_ids.to_vec();
188 ensure!(
189 logits.len() == batch * vocab,
190 "expected last-token logits [{batch}, {vocab}], got {}",
191 logits.len()
192 );
193 let mut next = argmax_token(logits);
194
195 let kv_start = 1usize;
196 let mut kv_caches: Vec<Vec<f32>> = outs[kv_start..].to_vec();
197 drop(prefill);
198
199 let layers = self.cfg.text_config.num_hidden_layers;
200 let key_past: Vec<String> = (0..layers)
201 .flat_map(|i| [format!("past_k_{i}"), format!("past_v_{i}")])
202 .collect();
203
204 let llama = self.cfg.llama_config();
205 let inv_freq = resolve_inv_freq(llama, None);
206
207 for past_len in seq..seq.saturating_add(self.max_new_tokens) {
208 if next == 0 {
209 break;
210 }
211 tokens.push(next);
212
213 let mut wm_dec = self.weight_store.load_language_model_weights()?;
214 let dec_built =
215 build_voxtral_decode_built(&self.cfg, &mut wm_dec, batch, past_len, false)?;
216 let dec_params = dec_built.params().clone();
217 let mut dec = compile_built(dec_built, self.device)?;
218 for (n, d) in &dec_params {
219 dec.set_param(n, d);
220 }
221 drop(wm_dec);
222
223 let token_f = [next as f32];
224 let (cos, sin) = rope_slice(&inv_freq, past_len);
225 let mut dec_in: Vec<(&str, &[f32])> = vec![
226 ("input_ids", &token_f),
227 ("rope_cos", cos.as_slice()),
228 ("rope_sin", sin.as_slice()),
229 ];
230 for i in 0..layers {
231 dec_in.push((key_past[2 * i].as_str(), kv_caches[2 * i].as_slice()));
232 dec_in.push((
233 key_past[2 * i + 1].as_str(),
234 kv_caches[2 * i + 1].as_slice(),
235 ));
236 }
237 let step_out = dec.run(&dec_in);
238 next = argmax_token(&step_out[0]);
239 kv_caches = step_out[1..].to_vec();
240 }
241
242 Ok(tokens)
243 }
244
245 pub fn transcribe_wav(&self, wav: &Path, language: Option<&str>) -> Result<Vec<u32>> {
246 let (mel, prompt) = pcm_to_mel_and_prompt(self.model_dir(), Some(wav), language)?;
247 self.generate(&prompt, &mel)
248 }
249}