// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! End-to-end TTS runner — native Rust only.
18use crate::acoustic::AcousticTransformer;
19use crate::backbone::NativeTtsEngine;
20use crate::bench::VoxtralTtsBenchReport;
21use crate::codec::decoder::CodecDecoder;
22use crate::config::VoxtralTtsConfig;
23use crate::generation::GenerationConfig;
24use crate::load::{PREFIX_CODEC, VoxtralTtsWeightStore};
25use crate::options::{VoxtralTtsOptions, VoxtralTtsRunnerBuilder};
26use crate::prompt_tokens::load_prompt_tokens;
27use crate::speech_tokenizer::SpeechTokenizer;
28use crate::tokens::PRESET_VOICES;
29use crate::voice::VoiceEmbedding;
30use crate::voice_clone::{VoiceCloneSupport, encode_reference_wav, voice_clone_support};
31use anyhow::{Context, Result, bail};
32use rlx_runtime::Device;
33use std::path::Path;
34
35pub struct VoxtralTtsRunner {
36    cfg: VoxtralTtsConfig,
37    store: VoxtralTtsWeightStore,
38    options: VoxtralTtsOptions,
39    codec: CodecDecoder,
40    #[allow(dead_code)]
41    acoustic: AcousticTransformer,
42    native: Option<NativeTtsEngine>,
43}
44
45impl VoxtralTtsRunner {
46    pub fn builder() -> VoxtralTtsRunnerBuilder {
47        VoxtralTtsRunnerBuilder::default()
48    }
49
50    pub fn open(model_dir: &Path) -> Result<Self> {
51        Self::open_with_options(model_dir, VoxtralTtsOptions::default())
52    }
53
54    pub fn open_with_options(model_dir: &Path, options: VoxtralTtsOptions) -> Result<Self> {
55        options.validate()?;
56        let store = VoxtralTtsWeightStore::open(model_dir)?;
57        let cfg = VoxtralTtsConfig::from_model_dir(store.model_dir())?;
58        let codec_tensors = store.tensor_snapshot(PREFIX_CODEC)?;
59        let codec =
60            CodecDecoder::from_tensors(PREFIX_CODEC, &codec_tensors, &cfg.audio_config.codec_args)?;
61        let acoustic_tensors = store.tensor_snapshot(crate::load::PREFIX_ACOUSTIC)?;
62        let acoustic = AcousticTransformer::from_tensors(
63            crate::load::PREFIX_ACOUSTIC,
64            &acoustic_tensors,
65            &cfg.audio_config.audio_model_args.acoustic_transformer_args,
66            cfg.audio_config.audio_model_args.n_acoustic_codebook,
67            cfg.audio_config.audio_model_args.semantic_codebook_size,
68        )?;
69        Ok(Self {
70            cfg,
71            store,
72            options,
73            codec,
74            acoustic,
75            native: None,
76        })
77    }
78
79    pub fn config(&self) -> &VoxtralTtsConfig {
80        &self.cfg
81    }
82
83    pub fn device(&self) -> Device {
84        self.options.device
85    }
86
87    pub fn options(&self) -> &VoxtralTtsOptions {
88        &self.options
89    }
90
91    pub fn model_dir(&self) -> &Path {
92        self.store.model_dir()
93    }
94
95    pub fn decode_codes_to_pcm(&self, codes: &[u32], n_frames: usize) -> Result<Vec<f32>> {
96        self.codec.decode_codes(codes, n_frames)
97    }
98
99    pub fn decode_codes_file(&self, codes_path: &Path, out_wav: &Path) -> Result<()> {
100        let (codes, n_frames) = parse_codes_file(codes_path)?;
101        let pcm = self.decode_codes_to_pcm(&codes, n_frames)?;
102        write_wav_mono(
103            out_wav,
104            &pcm,
105            self.cfg.audio_config.codec_args.sampling_rate as u32,
106        )
107    }
108
109    pub fn voice_clone_support(&self) -> VoiceCloneSupport {
110        voice_clone_support(&self.store)
111    }
112
113    /// Native path with an explicit voice embedding (preset or cloned).
114    pub fn synthesize_native_with_voice(
115        &mut self,
116        prompt_tokens: &[u32],
117        voice_emb: &VoiceEmbedding,
118        out_wav: &Path,
119        gen_cfg: &GenerationConfig,
120    ) -> Result<()> {
121        if self.native.is_none() {
122            self.native = Some(NativeTtsEngine::open(
123                &self.store,
124                &self.cfg,
125                &self.options,
126            )?);
127        }
128        let engine = self.native.as_mut().unwrap();
129        let pcm = engine.synthesize(prompt_tokens, voice_emb, gen_cfg)?;
130        write_wav_mono(
131            out_wav,
132            &pcm,
133            self.cfg.audio_config.codec_args.sampling_rate as u32,
134        )
135    }
136
137    /// Encode reference audio then synthesize (requires injected encoder weights).
138    pub fn synthesize_cloned_from_wav(
139        &mut self,
140        prompt_tokens: &[u32],
141        reference_wav: &Path,
142        out_wav: &Path,
143        gen_cfg: &GenerationConfig,
144    ) -> Result<()> {
145        let voice_emb = encode_reference_wav(&self.store, &self.cfg, reference_wav, "cloned")?;
146        self.synthesize_native_with_voice(prompt_tokens, &voice_emb, out_wav, gen_cfg)
147    }
148
149    /// Encode reference audio, tokenize for its frame count, then synthesize.
150    pub fn synthesize_cloned_with_text(
151        &mut self,
152        text: &str,
153        reference_wav: &Path,
154        out_wav: &Path,
155        gen_cfg: &GenerationConfig,
156    ) -> Result<()> {
157        let voice_emb = encode_reference_wav(&self.store, &self.cfg, reference_wav, "cloned")?;
158        let tok = SpeechTokenizer::from_model_dir(self.model_dir())?;
159        let prompt_tokens = tok.encode_speech_with_n_audio(text, voice_emb.n_tokens as u32)?;
160        self.synthesize_native_with_voice(&prompt_tokens, &voice_emb, out_wav, gen_cfg)
161    }
162
163    /// Native path: compiled LM + acoustic (or eager fallbacks) + codec decode on CPU.
164    pub fn synthesize_native(
165        &mut self,
166        prompt_tokens: &[u32],
167        voice: &str,
168        out_wav: &Path,
169        gen_cfg: &GenerationConfig,
170    ) -> Result<()> {
171        let hidden = self.cfg.text_config.hidden_size;
172        let voice_emb = resolve_preset_voice(self.model_dir(), voice, hidden)?;
173        self.synthesize_native_with_voice(prompt_tokens, &voice_emb, out_wav, gen_cfg)
174    }
175
176    pub fn synthesize_native_from_token_file(
177        &mut self,
178        prompt_tokens_path: &Path,
179        voice: &str,
180        out_wav: &Path,
181        gen_cfg: &GenerationConfig,
182    ) -> Result<()> {
183        let tokens = load_prompt_tokens(prompt_tokens_path)?;
184        self.synthesize_native(&tokens, voice, out_wav, gen_cfg)
185    }
186
187    pub fn synthesize_cloned_from_token_file(
188        &mut self,
189        prompt_tokens_path: &Path,
190        reference_wav: &Path,
191        out_wav: &Path,
192        gen_cfg: &GenerationConfig,
193    ) -> Result<()> {
194        let tokens = load_prompt_tokens(prompt_tokens_path)?;
195        self.synthesize_cloned_from_wav(&tokens, reference_wav, out_wav, gen_cfg)
196    }
197
198    pub fn synthesize_native_with_embedding_file(
199        &mut self,
200        prompt_tokens_path: &Path,
201        voice_embedding: &Path,
202        out_wav: &Path,
203        gen_cfg: &GenerationConfig,
204    ) -> Result<()> {
205        let tokens = load_prompt_tokens(prompt_tokens_path)?;
206        let hidden = self.cfg.text_config.hidden_size;
207        let voice_emb = VoiceEmbedding::load_f32(voice_embedding, "custom", hidden)?;
208        self.synthesize_native_with_voice(&tokens, &voice_emb, out_wav, gen_cfg)
209    }
210
211    pub fn encode_reference_to_file(
212        &self,
213        reference_wav: &Path,
214        out_f32: &Path,
215        voice_name: &str,
216    ) -> Result<VoiceEmbedding> {
217        crate::voice_clone::encode_reference_wav_to_file(
218            &self.store,
219            &self.cfg,
220            reference_wav,
221            out_f32,
222            voice_name,
223        )
224    }
225
226    /// Timed native synthesis on a fixed prompt (same weights for all option sets).
227    pub fn bench_native_profiled(
228        &mut self,
229        prompt_tokens: &[u32],
230        voice: &str,
231        gen_cfg: &GenerationConfig,
232        options: &VoxtralTtsOptions,
233    ) -> Result<VoxtralTtsBenchReport> {
234        options.validate()?;
235        let hidden = self.cfg.text_config.hidden_size;
236        let voice_emb = resolve_preset_voice(self.model_dir(), voice, hidden)?;
237        let mut engine = NativeTtsEngine::open(&self.store, &self.cfg, options)?;
238        let (_, report) = engine.synthesize_profiled(prompt_tokens, &voice_emb, gen_cfg)?;
239        Ok(report)
240    }
241
242    pub fn synthesize_native_codes(
243        &mut self,
244        prompt_tokens: &[u32],
245        voice: &str,
246        gen_cfg: &GenerationConfig,
247    ) -> Result<Vec<u32>> {
248        let hidden = self.cfg.text_config.hidden_size;
249        let voice_emb = resolve_preset_voice(self.model_dir(), voice, hidden)?;
250        if self.native.is_none() {
251            self.native = Some(NativeTtsEngine::open(
252                &self.store,
253                &self.cfg,
254                &self.options,
255            )?);
256        }
257        let engine = self.native.as_mut().unwrap();
258        engine.synthesize_codes(prompt_tokens, &voice_emb, gen_cfg)
259    }
260}
261
262fn resolve_preset_voice(model_dir: &Path, voice: &str, hidden: usize) -> Result<VoiceEmbedding> {
263    if PRESET_VOICES.contains(&voice) {
264        return VoiceEmbedding::load(model_dir, voice, hidden);
265    }
266    bail!(
267        "unknown preset voice {voice:?}; expected one of {} (or use --reference-wav / --voice-embedding)",
268        PRESET_VOICES.join(", ")
269    )
270}
271
272pub fn parse_codes_file(path: &Path) -> Result<(Vec<u32>, usize)> {
273    let text = std::fs::read_to_string(path).with_context(|| format!("read {}", path.display()))?;
274    let mut lines = text.lines();
275    let n_frames: usize = lines
276        .next()
277        .ok_or_else(|| anyhow::anyhow!("empty codes file"))?
278        .parse()
279        .context("parse frame count")?;
280    let body = lines.next().unwrap_or_default();
281    let codes: Vec<u32> = body
282        .split_whitespace()
283        .map(|s| s.parse().context("parse code"))
284        .collect::<Result<_>>()?;
285    Ok((codes, n_frames))
286}
287
288pub fn write_wav_mono(path: &Path, pcm: &[f32], sample_rate: u32) -> Result<()> {
289    let mut bytes = Vec::with_capacity(44 + pcm.len() * 2);
290    bytes.extend_from_slice(b"RIFF");
291    let data_bytes = (pcm.len() * 2 + 36) as u32;
292    bytes.extend_from_slice(&data_bytes.to_le_bytes());
293    bytes.extend_from_slice(b"WAVEfmt ");
294    bytes.extend_from_slice(&16u32.to_le_bytes());
295    bytes.extend_from_slice(&1u16.to_le_bytes()); // PCM
296    bytes.extend_from_slice(&1u16.to_le_bytes()); // mono
297    bytes.extend_from_slice(&sample_rate.to_le_bytes());
298    bytes.extend_from_slice(&(sample_rate * 2).to_le_bytes());
299    bytes.extend_from_slice(&2u16.to_le_bytes());
300    bytes.extend_from_slice(&16u16.to_le_bytes());
301    bytes.extend_from_slice(b"data");
302    bytes.extend_from_slice(&((pcm.len() * 2) as u32).to_le_bytes());
303    for &s in pcm {
304        let v = (s.clamp(-1.0, 1.0) * i16::MAX as f32) as i16;
305        bytes.extend_from_slice(&v.to_le_bytes());
306    }
307    std::fs::write(path, bytes).with_context(|| format!("write {}", path.display()))?;
308    Ok(())
309}