1use crate::acoustic::AcousticTransformer;
19use crate::backbone::NativeTtsEngine;
20use crate::bench::VoxtralTtsBenchReport;
21use crate::codec::decoder::CodecDecoder;
22use crate::config::VoxtralTtsConfig;
23use crate::generation::GenerationConfig;
24use crate::load::{PREFIX_CODEC, VoxtralTtsWeightStore};
25use crate::options::{VoxtralTtsOptions, VoxtralTtsRunnerBuilder};
26use crate::prompt_tokens::load_prompt_tokens;
27use crate::speech_tokenizer::SpeechTokenizer;
28use crate::tokens::PRESET_VOICES;
29use crate::voice::VoiceEmbedding;
30use crate::voice_clone::{VoiceCloneSupport, encode_reference_wav, voice_clone_support};
31use anyhow::{Context, Result, bail};
32use rlx_runtime::Device;
33use std::path::Path;
34
35pub struct VoxtralTtsRunner {
36 cfg: VoxtralTtsConfig,
37 store: VoxtralTtsWeightStore,
38 options: VoxtralTtsOptions,
39 codec: CodecDecoder,
40 #[allow(dead_code)]
41 acoustic: AcousticTransformer,
42 native: Option<NativeTtsEngine>,
43}
44
45impl VoxtralTtsRunner {
46 pub fn builder() -> VoxtralTtsRunnerBuilder {
47 VoxtralTtsRunnerBuilder::default()
48 }
49
50 pub fn open(model_dir: &Path) -> Result<Self> {
51 Self::open_with_options(model_dir, VoxtralTtsOptions::default())
52 }
53
54 pub fn open_with_options(model_dir: &Path, options: VoxtralTtsOptions) -> Result<Self> {
55 options.validate()?;
56 let store = VoxtralTtsWeightStore::open(model_dir)?;
57 let cfg = VoxtralTtsConfig::from_model_dir(store.model_dir())?;
58 let codec_tensors = store.tensor_snapshot(PREFIX_CODEC)?;
59 let codec =
60 CodecDecoder::from_tensors(PREFIX_CODEC, &codec_tensors, &cfg.audio_config.codec_args)?;
61 let acoustic_tensors = store.tensor_snapshot(crate::load::PREFIX_ACOUSTIC)?;
62 let acoustic = AcousticTransformer::from_tensors(
63 crate::load::PREFIX_ACOUSTIC,
64 &acoustic_tensors,
65 &cfg.audio_config.audio_model_args.acoustic_transformer_args,
66 cfg.audio_config.audio_model_args.n_acoustic_codebook,
67 cfg.audio_config.audio_model_args.semantic_codebook_size,
68 )?;
69 Ok(Self {
70 cfg,
71 store,
72 options,
73 codec,
74 acoustic,
75 native: None,
76 })
77 }
78
79 pub fn config(&self) -> &VoxtralTtsConfig {
80 &self.cfg
81 }
82
83 pub fn device(&self) -> Device {
84 self.options.device
85 }
86
87 pub fn options(&self) -> &VoxtralTtsOptions {
88 &self.options
89 }
90
91 pub fn model_dir(&self) -> &Path {
92 self.store.model_dir()
93 }
94
95 pub fn decode_codes_to_pcm(&self, codes: &[u32], n_frames: usize) -> Result<Vec<f32>> {
96 self.codec.decode_codes(codes, n_frames)
97 }
98
99 pub fn decode_codes_file(&self, codes_path: &Path, out_wav: &Path) -> Result<()> {
100 let (codes, n_frames) = parse_codes_file(codes_path)?;
101 let pcm = self.decode_codes_to_pcm(&codes, n_frames)?;
102 write_wav_mono(
103 out_wav,
104 &pcm,
105 self.cfg.audio_config.codec_args.sampling_rate as u32,
106 )
107 }
108
109 pub fn voice_clone_support(&self) -> VoiceCloneSupport {
110 voice_clone_support(&self.store)
111 }
112
113 pub fn synthesize_native_with_voice(
115 &mut self,
116 prompt_tokens: &[u32],
117 voice_emb: &VoiceEmbedding,
118 out_wav: &Path,
119 gen_cfg: &GenerationConfig,
120 ) -> Result<()> {
121 if self.native.is_none() {
122 self.native = Some(NativeTtsEngine::open(
123 &self.store,
124 &self.cfg,
125 &self.options,
126 )?);
127 }
128 let engine = self.native.as_mut().unwrap();
129 let pcm = engine.synthesize(prompt_tokens, voice_emb, gen_cfg)?;
130 write_wav_mono(
131 out_wav,
132 &pcm,
133 self.cfg.audio_config.codec_args.sampling_rate as u32,
134 )
135 }
136
137 pub fn synthesize_cloned_from_wav(
139 &mut self,
140 prompt_tokens: &[u32],
141 reference_wav: &Path,
142 out_wav: &Path,
143 gen_cfg: &GenerationConfig,
144 ) -> Result<()> {
145 let voice_emb = encode_reference_wav(&self.store, &self.cfg, reference_wav, "cloned")?;
146 self.synthesize_native_with_voice(prompt_tokens, &voice_emb, out_wav, gen_cfg)
147 }
148
149 pub fn synthesize_cloned_with_text(
151 &mut self,
152 text: &str,
153 reference_wav: &Path,
154 out_wav: &Path,
155 gen_cfg: &GenerationConfig,
156 ) -> Result<()> {
157 let voice_emb = encode_reference_wav(&self.store, &self.cfg, reference_wav, "cloned")?;
158 let tok = SpeechTokenizer::from_model_dir(self.model_dir())?;
159 let prompt_tokens = tok.encode_speech_with_n_audio(text, voice_emb.n_tokens as u32)?;
160 self.synthesize_native_with_voice(&prompt_tokens, &voice_emb, out_wav, gen_cfg)
161 }
162
163 pub fn synthesize_native(
165 &mut self,
166 prompt_tokens: &[u32],
167 voice: &str,
168 out_wav: &Path,
169 gen_cfg: &GenerationConfig,
170 ) -> Result<()> {
171 let hidden = self.cfg.text_config.hidden_size;
172 let voice_emb = resolve_preset_voice(self.model_dir(), voice, hidden)?;
173 self.synthesize_native_with_voice(prompt_tokens, &voice_emb, out_wav, gen_cfg)
174 }
175
176 pub fn synthesize_native_from_token_file(
177 &mut self,
178 prompt_tokens_path: &Path,
179 voice: &str,
180 out_wav: &Path,
181 gen_cfg: &GenerationConfig,
182 ) -> Result<()> {
183 let tokens = load_prompt_tokens(prompt_tokens_path)?;
184 self.synthesize_native(&tokens, voice, out_wav, gen_cfg)
185 }
186
187 pub fn synthesize_cloned_from_token_file(
188 &mut self,
189 prompt_tokens_path: &Path,
190 reference_wav: &Path,
191 out_wav: &Path,
192 gen_cfg: &GenerationConfig,
193 ) -> Result<()> {
194 let tokens = load_prompt_tokens(prompt_tokens_path)?;
195 self.synthesize_cloned_from_wav(&tokens, reference_wav, out_wav, gen_cfg)
196 }
197
198 pub fn synthesize_native_with_embedding_file(
199 &mut self,
200 prompt_tokens_path: &Path,
201 voice_embedding: &Path,
202 out_wav: &Path,
203 gen_cfg: &GenerationConfig,
204 ) -> Result<()> {
205 let tokens = load_prompt_tokens(prompt_tokens_path)?;
206 let hidden = self.cfg.text_config.hidden_size;
207 let voice_emb = VoiceEmbedding::load_f32(voice_embedding, "custom", hidden)?;
208 self.synthesize_native_with_voice(&tokens, &voice_emb, out_wav, gen_cfg)
209 }
210
211 pub fn encode_reference_to_file(
212 &self,
213 reference_wav: &Path,
214 out_f32: &Path,
215 voice_name: &str,
216 ) -> Result<VoiceEmbedding> {
217 crate::voice_clone::encode_reference_wav_to_file(
218 &self.store,
219 &self.cfg,
220 reference_wav,
221 out_f32,
222 voice_name,
223 )
224 }
225
226 pub fn bench_native_profiled(
228 &mut self,
229 prompt_tokens: &[u32],
230 voice: &str,
231 gen_cfg: &GenerationConfig,
232 options: &VoxtralTtsOptions,
233 ) -> Result<VoxtralTtsBenchReport> {
234 options.validate()?;
235 let hidden = self.cfg.text_config.hidden_size;
236 let voice_emb = resolve_preset_voice(self.model_dir(), voice, hidden)?;
237 let mut engine = NativeTtsEngine::open(&self.store, &self.cfg, options)?;
238 let (_, report) = engine.synthesize_profiled(prompt_tokens, &voice_emb, gen_cfg)?;
239 Ok(report)
240 }
241
242 pub fn synthesize_native_codes(
243 &mut self,
244 prompt_tokens: &[u32],
245 voice: &str,
246 gen_cfg: &GenerationConfig,
247 ) -> Result<Vec<u32>> {
248 let hidden = self.cfg.text_config.hidden_size;
249 let voice_emb = resolve_preset_voice(self.model_dir(), voice, hidden)?;
250 if self.native.is_none() {
251 self.native = Some(NativeTtsEngine::open(
252 &self.store,
253 &self.cfg,
254 &self.options,
255 )?);
256 }
257 let engine = self.native.as_mut().unwrap();
258 engine.synthesize_codes(prompt_tokens, &voice_emb, gen_cfg)
259 }
260}
261
262fn resolve_preset_voice(model_dir: &Path, voice: &str, hidden: usize) -> Result<VoiceEmbedding> {
263 if PRESET_VOICES.contains(&voice) {
264 return VoiceEmbedding::load(model_dir, voice, hidden);
265 }
266 bail!(
267 "unknown preset voice {voice:?}; expected one of {} (or use --reference-wav / --voice-embedding)",
268 PRESET_VOICES.join(", ")
269 )
270}
271
272pub fn parse_codes_file(path: &Path) -> Result<(Vec<u32>, usize)> {
273 let text = std::fs::read_to_string(path).with_context(|| format!("read {}", path.display()))?;
274 let mut lines = text.lines();
275 let n_frames: usize = lines
276 .next()
277 .ok_or_else(|| anyhow::anyhow!("empty codes file"))?
278 .parse()
279 .context("parse frame count")?;
280 let body = lines.next().unwrap_or_default();
281 let codes: Vec<u32> = body
282 .split_whitespace()
283 .map(|s| s.parse().context("parse code"))
284 .collect::<Result<_>>()?;
285 Ok((codes, n_frames))
286}
287
288pub fn write_wav_mono(path: &Path, pcm: &[f32], sample_rate: u32) -> Result<()> {
289 let mut bytes = Vec::with_capacity(44 + pcm.len() * 2);
290 bytes.extend_from_slice(b"RIFF");
291 let data_bytes = (pcm.len() * 2 + 36) as u32;
292 bytes.extend_from_slice(&data_bytes.to_le_bytes());
293 bytes.extend_from_slice(b"WAVEfmt ");
294 bytes.extend_from_slice(&16u32.to_le_bytes());
295 bytes.extend_from_slice(&1u16.to_le_bytes()); bytes.extend_from_slice(&1u16.to_le_bytes()); bytes.extend_from_slice(&sample_rate.to_le_bytes());
298 bytes.extend_from_slice(&(sample_rate * 2).to_le_bytes());
299 bytes.extend_from_slice(&2u16.to_le_bytes());
300 bytes.extend_from_slice(&16u16.to_le_bytes());
301 bytes.extend_from_slice(b"data");
302 bytes.extend_from_slice(&((pcm.len() * 2) as u32).to_le_bytes());
303 for &s in pcm {
304 let v = (s.clamp(-1.0, 1.0) * i16::MAX as f32) as i16;
305 bytes.extend_from_slice(&v.to_le_bytes());
306 }
307 std::fs::write(path, bytes).with_context(|| format!("write {}", path.display()))?;
308 Ok(())
309}