1#![allow(dead_code, unused_imports, unused_variables, unused_mut, unused_parens)]
7
8use std::collections::HashMap;
9use std::sync::{Arc, OnceLock};
10
11use async_trait::async_trait;
12use candle_core::{DType, Device as CandleDevice, Tensor};
13use candle_nn::VarBuilder;
14use ferrum_interfaces::{
15 model_executor::{
16 AttentionType, DecodeInput, DecodeOutput, ExecutorCapabilities, MemoryRequirements,
17 PrefillInput, PrefillOutput,
18 },
19 ModelExecutor, TensorRef,
20};
21use ferrum_types::{DataType, Device, FerrumError, ModelInfo, ModelType, Result};
22use tracing::info;
23
24use super::common;
25use crate::multimodal::qwen3_tts::{Qwen3TTSTalker, SubTalker, TalkerConfig};
26use crate::multimodal::qwen3_tts_backbone::TalkerBackboneBackend;
27use crate::multimodal::qwen3_tts_vocoder::{Qwen3TTSVocoder, VocoderConfig};
28use crate::multimodal::speaker_encoder::{mel_spectrogram_speaker_encoder, SpeakerEncoder};
29use crate::multimodal::speech_tokenizer_encoder::SpeechTokenizerEncoder;
30use ferrum_quantization::NativeSafetensorsLoader;
31
/// Swaps the backbones of `talker` and `sub_talker` for `CudaBackend`-based ones.
///
/// A `NativeSafetensorsLoader` is opened on `model_dir` and shared by both
/// backbones: the talker gets `TalkerBackboneBackend::new`, the sub-talker
/// gets `TalkerBackboneBackend::new_code_predictor`.
///
/// # Errors
///
/// Propagates any failure from opening the loader or from building either
/// backbone. The talker override is installed before the sub-talker backbone
/// is built, so a failure in the second step leaves the first one in place.
#[cfg(feature = "cuda")]
fn install_cuda_backend_overrides(
    cfg: &TalkerConfig,
    model_dir: &std::path::Path,
    talker: &mut Qwen3TTSTalker,
    sub_talker: &mut SubTalker,
) -> Result<()> {
    use ferrum_kernels::backend::cuda::CudaBackend;

    type CudaBackbone = TalkerBackboneBackend<CudaBackend>;

    let weights: NativeSafetensorsLoader<CudaBackend> = NativeSafetensorsLoader::open(model_dir)?;

    let talker_backbone = CudaBackbone::new(cfg, &weights)?;
    talker.set_backend_override(Box::new(talker_backbone));

    let predictor_backbone = CudaBackbone::new_code_predictor(cfg, &weights)?;
    sub_talker.set_backend_override(Box::new(predictor_backbone));

    Ok(())
}
50
/// Sample rate, in Hz, used when reporting waveform durations and returned
/// by `TtsModelExecutor::sample_rate`.
const SAMPLE_RATE: usize = 24_000;

/// Upper bound on the number of generation steps (codec frames) per request.
const MAX_CODEC_TOKENS: usize = 2_000;

/// Default sampling temperature, used when `FERRUM_TTS_TEMP` is unset or
/// cannot be parsed.
const TEMPERATURE: f32 = 0.9;

/// Top-k cutoff handed to the token samplers.
const TOP_K: usize = 50;

/// Penalty applied to the logits of already-generated tokens: positive logits
/// are divided by it, all others are multiplied by it.
const REPETITION_PENALTY: f32 = 1.05;
61
62#[derive(Debug, Clone, PartialEq)]
63struct TtsRuntimeEnv {
64 tts_temperature: f32,
65 st_temperature: Option<f32>,
66 ref_pcm: Option<String>,
67 ref_codes: Option<String>,
68 min_frames: Option<usize>,
69}
70
71impl TtsRuntimeEnv {
72 fn from_env() -> Self {
73 Self::from_env_vars(std::env::vars())
74 }
75
76 fn from_env_vars<I, K, V>(vars: I) -> Self
77 where
78 I: IntoIterator<Item = (K, V)>,
79 K: AsRef<str>,
80 V: Into<String>,
81 {
82 let mut tts_temperature = None;
83 let mut st_temperature = None;
84 let mut ref_pcm = None;
85 let mut ref_codes = None;
86 let mut min_frames = None;
87
88 for (key, value) in vars {
89 let value = value.into();
90 match key.as_ref() {
91 "FERRUM_TTS_TEMP" => tts_temperature = value.parse::<f32>().ok(),
92 "FERRUM_ST_TEMP" => st_temperature = value.parse::<f32>().ok(),
93 "FERRUM_REF_PCM" => ref_pcm = Some(value),
94 "FERRUM_REF_CODES" => ref_codes = Some(value),
95 "FERRUM_TTS_MIN_FRAMES" => min_frames = value.parse::<usize>().ok(),
96 _ => {}
97 }
98 }
99
100 Self {
101 tts_temperature: tts_temperature.unwrap_or(TEMPERATURE),
102 st_temperature,
103 ref_pcm,
104 ref_codes,
105 min_frames,
106 }
107 }
108
109 fn st_temperature(&self) -> f32 {
110 self.st_temperature.unwrap_or(self.tts_temperature)
111 }
112}
113
114fn tts_runtime_env() -> &'static TtsRuntimeEnv {
115 static CONFIG: OnceLock<TtsRuntimeEnv> = OnceLock::new();
116 CONFIG.get_or_init(TtsRuntimeEnv::from_env)
117}
118
119fn tts_temperature() -> f32 {
120 tts_runtime_env().tts_temperature
121}
122
123fn st_temperature() -> f32 {
124 tts_runtime_env().st_temperature()
125}
126
127pub struct TtsModelExecutor {
129 talker: Qwen3TTSTalker,
130 sub_talker: SubTalker,
131 vocoder: Qwen3TTSVocoder,
132 text_tokenizer: tokenizers::Tokenizer,
133 config: TalkerConfig,
134 info: ModelInfo,
135 speaker_encoder: Option<SpeakerEncoder>,
136 speech_tokenizer_encoder: Option<SpeechTokenizerEncoder>,
137}
138
139impl TtsModelExecutor {
140 pub fn from_path(model_path: &str, device: CandleDevice, dtype: DType) -> Result<Self> {
146 let dir = std::path::Path::new(model_path);
147
148 let config_json: serde_json::Value = {
150 let config_path = dir.join("config.json");
151 let data = std::fs::read_to_string(&config_path)
152 .map_err(|e| FerrumError::model(format!("read config.json: {e}")))?;
153 serde_json::from_str(&data)
154 .map_err(|e| FerrumError::model(format!("parse config.json: {e}")))?
155 };
156 let config = TalkerConfig::from_json(&config_json)?;
157
158 let text_tokenizer = load_bpe_tokenizer(dir)?;
160
161 let talker_weights = find_safetensor_files(dir, "model")?;
163 let talker_vb = unsafe {
164 VarBuilder::from_mmaped_safetensors(&talker_weights, dtype, &device)
165 .map_err(|e| FerrumError::model(format!("load talker weights: {e}")))?
166 };
167 let mut talker = Qwen3TTSTalker::load(&config, talker_vb.clone(), device.clone())?;
168
169 let mut sub_talker = SubTalker::load(&config, talker_vb.clone(), device.clone())?;
171
172 let spk_enc_dim = config_json
174 .get("speaker_encoder_config")
175 .and_then(|c| c.get("enc_dim"))
176 .and_then(|v| v.as_u64())
177 .unwrap_or(1024) as usize;
178 let speaker_encoder =
179 SpeakerEncoder::load_with_dim(talker_vb.pp("speaker_encoder"), spk_enc_dim)
180 .map_err(|e| {
181 tracing::warn!("Speaker encoder not available: {e}");
182 e
183 })
184 .ok();
185
186 let vocoder_dir = dir.join("speech_tokenizer");
188 let vocoder_weights = find_safetensor_files(&vocoder_dir, "model")?;
189 let vocoder_vb = unsafe {
190 VarBuilder::from_mmaped_safetensors(&vocoder_weights, dtype, &device)
191 .map_err(|e| FerrumError::model(format!("load vocoder weights: {e}")))?
192 };
193 let vocoder_config = VocoderConfig::default();
194 let vocoder = Qwen3TTSVocoder::load(&vocoder_config, vocoder_vb.clone())?;
195
196 let speech_tokenizer_encoder = if vocoder_dir.join("config.json").exists() {
200 let cpu_vb = unsafe {
201 VarBuilder::from_mmaped_safetensors(&vocoder_weights, dtype, &CandleDevice::Cpu)
202 .map_err(|e| FerrumError::model(format!("load encoder cpu: {e}")))?
203 };
204 SpeechTokenizerEncoder::load(cpu_vb.pp("encoder"), CandleDevice::Cpu)
205 .map_err(|e| {
206 tracing::warn!("Speech tokenizer encoder not available: {e}");
207 e
208 })
209 .ok()
210 } else {
211 None
212 };
213
214 #[cfg(feature = "cuda")]
223 if matches!(&device, CandleDevice::Cuda(_)) {
224 match install_cuda_backend_overrides(&config, dir, &mut talker, &mut sub_talker) {
225 Ok(()) => {
226 tracing::info!(
227 "TtsModelExecutor: Backend<CudaBackend> installed for Talker + SubTalker"
228 );
229 }
230 Err(e) => {
231 tracing::warn!(
232 "TtsModelExecutor: Backend<CudaBackend> install failed ({e}); \
233 falling back to candle/fused path (CUDA voice-clone may produce garbage)"
234 );
235 }
236 }
237 }
238
239 let info = ModelInfo {
240 model_id: ferrum_types::ModelId(model_path.to_string()),
241 model_type: ModelType::Custom("qwen3-tts".to_string()),
242 hidden_size: config.hidden_size,
243 vocab_size: config.vocab_size,
244 num_layers: config.num_hidden_layers,
245 num_heads: config.num_attention_heads,
246 num_kv_heads: config.num_key_value_heads,
247 num_parameters: 0,
248 max_sequence_length: config.max_position_embeddings,
249 device: match &device {
250 CandleDevice::Cpu => Device::CPU,
251 CandleDevice::Cuda(_) => Device::CUDA(0),
252 #[cfg(any(target_os = "macos", target_os = "ios"))]
253 CandleDevice::Metal(_) => Device::Metal,
254 #[cfg(not(any(target_os = "macos", target_os = "ios")))]
255 CandleDevice::Metal(_) => Device::CPU,
256 },
257 dtype: match dtype {
258 DType::F32 => DataType::FP32,
259 DType::F16 => DataType::FP16,
260 DType::BF16 => DataType::BF16,
261 _ => DataType::FP32,
262 },
263 version: None,
264 license: None,
265 metadata: HashMap::new(),
266 };
267
268 info!(
269 "TtsModelExecutor: {} (hidden={}, layers={}, codec_groups={})",
270 model_path, config.hidden_size, config.num_hidden_layers, config.num_code_groups,
271 );
272
273 Ok(Self {
274 talker,
275 sub_talker,
276 vocoder,
277 text_tokenizer,
278 config,
279 info,
280 speaker_encoder,
281 speech_tokenizer_encoder,
282 })
283 }
284
285 pub fn synthesize(&mut self, text: &str, language: &str) -> Result<Vec<f32>> {
293 self.talker.reset();
294
295 let device = self.talker.device().clone();
296
297 let encoding = self
299 .text_tokenizer
300 .encode(text, false)
301 .map_err(|e| FerrumError::model(format!("tokenize: {e}")))?;
302 let content_ids: Vec<u32> = encoding.get_ids().to_vec();
303
304 if content_ids.is_empty() {
305 return Err(FerrumError::model("empty text after tokenization"));
306 }
307
308 info!("TTS: content tokens = {}", content_ids.len());
309
310 let codec_eos = self.config.codec_eos_token_id;
311 let tts_pad = self.config.tts_pad_token_id;
312 let tts_bos = self.config.tts_bos_token_id;
313 let tts_eos = self.config.tts_eos_token_id;
314
315 let embed_text_ids = |ids: &[u32]| -> Result<Tensor> {
317 let t = Tensor::new(ids, &device)
318 .and_then(|t| t.unsqueeze(0))
319 .map_err(|e| FerrumError::model(format!("text ids: {e}")))?;
320 self.talker.embed_text(&t)
321 };
322 let embed_codec_ids = |ids: &[u32]| -> Result<Tensor> {
323 let t = Tensor::new(ids, &device)
324 .and_then(|t| t.unsqueeze(0))
325 .map_err(|e| FerrumError::model(format!("codec ids: {e}")))?;
326 self.talker.embed_codec(&t)
327 };
328
329 let im_start_id = 151644u32; let assistant_id = 77091u32; let newline_id = 198u32; let role_prefix_ids = [im_start_id, assistant_id, newline_id];
337
338 info!(
339 "TTS: role_prefix={} content={} tokens",
340 role_prefix_ids.len(),
341 content_ids.len()
342 );
343
344 let role_embed = embed_text_ids(&role_prefix_ids)?;
346
347 let resolved_lang = if language == "auto" {
349 "chinese"
350 } else {
351 language
352 };
353 let language_id = self
354 .config
355 .codec_language_id
356 .get(&resolved_lang.to_lowercase());
357 let codec_prefix_ids = if let Some(&lang_id) = language_id {
358 vec![
359 self.config.codec_think_id,
360 self.config.codec_think_bos_id,
361 lang_id,
362 self.config.codec_think_eos_id,
363 ]
364 } else {
365 vec![
366 self.config.codec_nothink_id,
367 self.config.codec_think_bos_id,
368 self.config.codec_think_eos_id,
369 ]
370 };
371 let speaker_token = if resolved_lang == "chinese" {
374 3065u32
375 } else {
376 3061u32
377 };
378 let codec_full = {
379 let mut v = codec_prefix_ids.clone();
380 v.push(speaker_token);
381 v.push(self.config.codec_pad_id);
382 v.push(self.config.codec_bos_id);
383 v
384 };
385 let codec_embed = embed_codec_ids(&codec_full)?;
386
387 let n_codec = codec_full.len();
389 let mut tts_prefix_ids = vec![tts_pad; n_codec - 1];
390 tts_prefix_ids.push(tts_bos);
391 let tts_prefix_embed = embed_text_ids(&tts_prefix_ids)?;
392
393 let codec_first = codec_embed
395 .narrow(1, 0, n_codec - 1)
396 .map_err(|e| FerrumError::model(format!("codec narrow: {e}")))?;
397 let n_prefix = n_codec - 1; let codec_prefix_part = codec_embed
406 .narrow(1, 0, n_prefix)
407 .map_err(|e| FerrumError::model(format!("codec narrow: {e}")))?;
408
409 let mut tts_text_prefix_ids = vec![tts_pad; n_prefix - 1];
411 tts_text_prefix_ids.push(tts_bos);
412 let tts_text_embed = embed_text_ids(&tts_text_prefix_ids)?;
413
414 let codec_hidden = (&tts_text_embed + &codec_prefix_part)
415 .map_err(|e| FerrumError::model(format!("prefix sum: {e}")))?;
416
417 let codec_bos_embed = codec_embed
419 .narrow(1, n_prefix, 1)
420 .map_err(|e| FerrumError::model(format!("codec bos: {e}")))?;
421
422 let first_text_combined = if !content_ids.is_empty() {
424 let first_text_embed = embed_text_ids(&content_ids[..1])?;
425 (&first_text_embed + &codec_bos_embed)
426 .map_err(|e| FerrumError::model(format!("first text+bos: {e}")))?
427 } else {
428 codec_bos_embed.clone()
429 };
430
431 let prefill_embeds = Tensor::cat(&[&role_embed, &codec_hidden, &first_text_combined], 1)
433 .map_err(|e| FerrumError::model(format!("prefill cat: {e}")))?;
434
435 let plen = prefill_embeds.dim(1).unwrap_or(0);
436 info!("TTS: prefill_len = {}", plen);
437 if let Ok(v) = prefill_embeds
439 .narrow(0, 0, 1)
440 .and_then(|t| t.narrow(1, 0, 1))
441 .and_then(|t| t.narrow(2, 0, 5))
442 .and_then(|t| t.flatten_all())
443 .and_then(|t| t.to_vec1::<f32>())
444 {
445 info!(" prefill_input pos0[:5] = {:?}", v);
446 }
447 if plen > 0 {
448 if let Ok(v) = prefill_embeds
449 .narrow(0, 0, 1)
450 .and_then(|t| t.narrow(1, plen - 1, 1))
451 .and_then(|t| t.narrow(2, 0, 5))
452 .and_then(|t| t.flatten_all())
453 .and_then(|t| t.to_vec1::<f32>())
454 {
455 info!(" prefill_input pos-1[:5] = {:?}", v);
456 }
457 }
458
459 let mut trailing_ids: Vec<u32> = if content_ids.len() > 1 {
461 content_ids[1..].to_vec()
462 } else {
463 Vec::new()
464 };
465 trailing_ids.push(tts_eos);
466 let trailing_text_embeds = embed_text_ids(&trailing_ids)?;
467 let trailing_text_len = trailing_text_embeds
468 .dim(1)
469 .map_err(|e| FerrumError::model(format!("trailing dim: {e}")))?;
470 let tts_pad_embed = embed_text_ids(&[tts_pad])?;
471
472 info!("TTS: trailing_text_len = {}", trailing_text_len);
473
474 let mut hidden = self.talker.forward_step(&prefill_embeds)?;
476
477 let current_logits = self.talker.logits(
479 &hidden
480 .narrow(1, hidden.dim(1).unwrap() - 1, 1)
481 .map_err(|e| FerrumError::model(format!("narrow: {e}")))?,
482 )?;
483
484 let mut all_codec_tokens: Vec<Vec<u32>> = Vec::new();
486 let mut current_logits = current_logits;
487
488 let suppress_start = self.config.vocab_size.saturating_sub(1024);
490 let suppress_end = self.config.vocab_size;
491 let mut generated_tokens: Vec<u32> = Vec::new();
492
493 for step in 0..MAX_CODEC_TOKENS {
494 let mut logits_vec = logits_to_vec(¤t_logits)?;
496 for i in suppress_start..suppress_end.min(logits_vec.len()) {
498 if i as u32 != codec_eos {
499 logits_vec[i] = f32::NEG_INFINITY;
500 }
501 }
502 for &prev_tok in &generated_tokens {
504 let idx = prev_tok as usize;
505 if idx < logits_vec.len() {
506 if logits_vec[idx] > 0.0 {
507 logits_vec[idx] /= REPETITION_PENALTY;
508 } else {
509 logits_vec[idx] *= REPETITION_PENALTY;
510 }
511 }
512 }
513 let next_token =
514 sample_token(&logits_vec, tts_temperature(), TOP_K, REPETITION_PENALTY);
515
516 if step < 10 {
517 let argmax_tok = logits_vec
519 .iter()
520 .enumerate()
521 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
522 .map(|(i, v)| (i, *v))
523 .unwrap_or((0, 0.0));
524 info!(
525 "TOKEN step={} sampled={} argmax=({}, {:.2})",
526 step, next_token, argmax_tok.0, argmax_tok.1
527 );
528 }
529
530 generated_tokens.push(next_token);
531
532 if next_token == codec_eos {
534 info!("TTS: codec EOS at step {}", step);
535 break;
536 }
537
538 let last_hidden = hidden
540 .narrow(1, hidden.dim(1).unwrap() - 1, 1)
541 .map_err(|e| FerrumError::model(format!("last_hidden: {e}")))?;
542
543 let token_tensor = Tensor::new(&[next_token], &device)
545 .map_err(|e| FerrumError::model(format!("token tensor: {e}")))?
546 .unsqueeze(0)
547 .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?;
548 let first_codec_embed = self.talker.embed_codec(&token_tensor)?;
549
550 let st_t0 = std::time::Instant::now();
552 let extra_codes = self.sub_talker.predict(
553 &last_hidden,
554 &first_codec_embed,
555 st_temperature(),
556 TOP_K,
557 )?;
558 if step == 0 {
559 info!(
560 " SubTalker: {:.1}ms",
561 st_t0.elapsed().as_secs_f64() * 1000.0
562 );
563 }
564
565 let mut frame_codes = vec![next_token];
566 frame_codes.extend_from_slice(&extra_codes);
567 all_codec_tokens.push(frame_codes);
568
569 let mut combined_embed = first_codec_embed.clone();
572 for (i, &code) in extra_codes.iter().enumerate() {
573 let code_t = Tensor::new(&[code], &device)
574 .and_then(|t| t.unsqueeze(0))
575 .map_err(|e| FerrumError::model(format!("code_t: {e}")))?;
576 let sub_embed = code_t
577 .apply(&self.sub_talker.codec_embeddings[i])
578 .map_err(|e| FerrumError::model(format!("sub_embed: {e}")))?;
579 combined_embed = (combined_embed + sub_embed)
580 .map_err(|e| FerrumError::model(format!("add embed: {e}")))?;
581 }
582
583 if step == 0 {
584 if let Ok(v) = combined_embed
585 .flatten_all()
586 .and_then(|t| t.narrow(0, 0, 5))
587 .and_then(|t| t.to_vec1::<f32>())
588 {
589 info!("STEP0 codec_sum[:5] = {:?} (before trailing)", v);
590 }
591 }
592 if step < trailing_text_len {
597 let trail = trailing_text_embeds
598 .narrow(1, step, 1)
599 .map_err(|e| FerrumError::model(format!("trailing narrow: {e}")))?;
600 combined_embed = (combined_embed + trail)
601 .map_err(|e| FerrumError::model(format!("add trailing: {e}")))?;
602 } else {
603 combined_embed = (combined_embed + &tts_pad_embed)
604 .map_err(|e| FerrumError::model(format!("add tts_pad: {e}")))?;
605 }
606
607 if step == 0 {
609 if let Ok(v) = first_codec_embed
610 .flatten_all()
611 .and_then(|t| t.narrow(0, 0, 5))
612 .and_then(|t| t.to_vec1::<f32>())
613 {
614 info!("STEP0 semantic[:5] = {:?}", v);
615 }
616 if let Ok(v) = combined_embed
617 .flatten_all()
618 .and_then(|t| t.narrow(0, 0, 5))
619 .and_then(|t| t.to_vec1::<f32>())
620 {
621 info!("STEP0 combined[:5] = {:?}", v);
622 }
623 }
624 let tk_t0 = std::time::Instant::now();
626 hidden = self.talker.forward_step(&combined_embed)?;
627 current_logits = self.talker.logits(&hidden)?;
628 if step == 0 {
629 info!(
630 " Talker step: {:.1}ms",
631 tk_t0.elapsed().as_secs_f64() * 1000.0
632 );
633 }
634 }
635
636 if all_codec_tokens.is_empty() {
637 return Err(FerrumError::model("no codec tokens generated"));
638 }
639
640 info!("TTS: generated {} codec frames", all_codec_tokens.len());
641
642 let num_frames = all_codec_tokens.len();
644 let num_groups = self.config.num_code_groups;
645 let mut flat_codes: Vec<u32> = vec![0; num_groups * num_frames];
646 for (t, frame) in all_codec_tokens.iter().enumerate() {
647 for (g, &code) in frame.iter().enumerate() {
648 flat_codes[g * num_frames + t] = code;
649 }
650 }
651
652 let codebook_size = 2048u32;
655 for code in &mut flat_codes {
656 if *code >= codebook_size {
657 *code = 0; }
659 }
660
661 let codes_tensor = Tensor::new(&flat_codes[..], &device)
662 .map_err(|e| FerrumError::model(format!("codes tensor: {e}")))?
663 .reshape((1, num_groups, num_frames))
664 .map_err(|e| FerrumError::model(format!("reshape codes: {e}")))?;
665
666 let waveform = self.vocoder.decode(&codes_tensor)?;
668
669 let samples: Vec<f32> = waveform
671 .squeeze(0)
672 .map_err(|e| FerrumError::model(format!("squeeze batch: {e}")))?
673 .squeeze(0)
674 .map_err(|e| FerrumError::model(format!("squeeze channel: {e}")))?
675 .to_vec1()
676 .map_err(|e| FerrumError::model(format!("to_vec1: {e}")))?;
677
678 info!(
679 "TTS: waveform {} samples ({:.2}s @ {}Hz)",
680 samples.len(),
681 samples.len() as f64 / SAMPLE_RATE as f64,
682 SAMPLE_RATE,
683 );
684
685 Ok(samples)
686 }
687
688 fn decode_frames(&mut self, frames: &[Vec<u32>], device: &CandleDevice) -> Result<Vec<f32>> {
691 let num_frames = frames.len();
692 if num_frames == 0 {
693 return Ok(vec![]);
694 }
695 let num_groups = self.config.num_code_groups;
696 let codebook_size = 2048u32;
697
698 let mut flat_codes: Vec<u32> = vec![0; num_groups * num_frames];
699 for (t, frame) in frames.iter().enumerate() {
700 for (g, &code) in frame.iter().take(num_groups).enumerate() {
701 flat_codes[g * num_frames + t] = if code >= codebook_size { 0 } else { code };
702 }
703 }
704
705 let codes_tensor = Tensor::new(&flat_codes[..], device)
706 .map_err(|e| FerrumError::model(format!("codes tensor: {e}")))?
707 .reshape((1, num_groups, num_frames))
708 .map_err(|e| FerrumError::model(format!("reshape codes: {e}")))?;
709
710 let waveform = self.vocoder.decode(&codes_tensor)?;
711 waveform
712 .squeeze(0)
713 .and_then(|t| t.squeeze(0))
714 .and_then(|t| t.to_vec1())
715 .map_err(|e| FerrumError::model(format!("waveform extract: {e}")))
716 }
717
718 pub fn synthesize_streaming<F: FnMut(usize, &[f32])>(
723 &mut self,
724 text: &str,
725 language: &str,
726 chunk_frames: usize,
727 mut on_chunk: F,
728 ) -> Result<Vec<Vec<f32>>> {
729 self.talker.reset();
732 let device = self.talker.device().clone();
733
734 let encoding = self
735 .text_tokenizer
736 .encode(text, false)
737 .map_err(|e| FerrumError::model(format!("tokenize: {e}")))?;
738 let content_ids: Vec<u32> = encoding.get_ids().to_vec();
739 if content_ids.is_empty() {
740 return Err(FerrumError::model("empty text after tokenization"));
741 }
742
743 let codec_eos = self.config.codec_eos_token_id;
744 let tts_pad = self.config.tts_pad_token_id;
745 let tts_bos = self.config.tts_bos_token_id;
746 let tts_eos = self.config.tts_eos_token_id;
747
748 let embed_text_ids = |ids: &[u32]| -> Result<Tensor> {
750 let t = Tensor::new(ids, &device)
751 .map_err(|e| FerrumError::model(format!("text tensor: {e}")))?
752 .unsqueeze(0)
753 .map_err(|e| FerrumError::model(format!("text unsqueeze: {e}")))?;
754 self.talker.embed_text(&t)
755 };
756 let embed_codec_ids = |ids: &[u32]| -> Result<Tensor> {
757 let t = Tensor::new(ids, &device)
758 .map_err(|e| FerrumError::model(format!("codec tensor: {e}")))?
759 .unsqueeze(0)
760 .map_err(|e| FerrumError::model(format!("codec unsqueeze: {e}")))?;
761 self.talker.embed_codec(&t)
762 };
763
764 let resolved_lang = if language.eq_ignore_ascii_case("auto") {
766 "chinese"
767 } else {
768 language
769 };
770 let language_id = self
771 .config
772 .codec_language_id
773 .get(&resolved_lang.to_lowercase());
774 let codec_prefix_ids = if let Some(&lang_id) = language_id {
775 vec![
776 self.config.codec_think_id,
777 self.config.codec_think_bos_id,
778 lang_id,
779 self.config.codec_think_eos_id,
780 ]
781 } else {
782 vec![
783 self.config.codec_nothink_id,
784 self.config.codec_think_bos_id,
785 self.config.codec_think_eos_id,
786 ]
787 };
788 let speaker_token = if resolved_lang == "chinese" {
789 3065u32
790 } else {
791 3061u32
792 };
793 let mut codec_ids = codec_prefix_ids;
794 codec_ids.push(speaker_token);
795 codec_ids.push(self.config.codec_pad_id);
796 codec_ids.push(self.config.codec_bos_id);
797 let codec_embed = embed_codec_ids(&codec_ids)?;
798 let n_codec = codec_embed
799 .dim(1)
800 .map_err(|e| FerrumError::model(format!("dim: {e}")))?;
801 let n_prefix = n_codec - 1;
802 let codec_prefix_part = codec_embed
803 .narrow(1, 0, n_prefix)
804 .map_err(|e| FerrumError::model(format!("narrow: {e}")))?;
805
806 let mut tts_text_prefix_ids = vec![tts_pad; n_prefix - 1];
807 tts_text_prefix_ids.push(tts_bos);
808 let tts_text_embed = embed_text_ids(&tts_text_prefix_ids)?;
809 let codec_hidden = (&tts_text_embed + &codec_prefix_part)
810 .map_err(|e| FerrumError::model(format!("sum: {e}")))?;
811 let codec_bos_embed = codec_embed
812 .narrow(1, n_prefix, 1)
813 .map_err(|e| FerrumError::model(format!("bos: {e}")))?;
814
815 let role_ids: &[u32] = &[151644, 77091, 198]; let role_embed = embed_text_ids(role_ids)?;
817
818 let first_text_combined = if !content_ids.is_empty() {
819 let first_text_embed = embed_text_ids(&content_ids[..1])?;
820 (&first_text_embed + &codec_bos_embed)
821 .map_err(|e| FerrumError::model(format!("first: {e}")))?
822 } else {
823 codec_bos_embed.clone()
824 };
825
826 let prefill_embeds = Tensor::cat(&[&role_embed, &codec_hidden, &first_text_combined], 1)
827 .map_err(|e| FerrumError::model(format!("prefill cat: {e}")))?;
828
829 let trailing_text_embeds = if content_ids.len() > 1 {
831 let remaining = embed_text_ids(&content_ids[1..])?;
832 let eos = embed_text_ids(&[tts_eos])?;
833 Tensor::cat(&[&remaining, &eos], 1)
834 .map_err(|e| FerrumError::model(format!("trailing: {e}")))?
835 } else {
836 embed_text_ids(&[tts_eos])?
837 };
838 let trailing_text_len = trailing_text_embeds
839 .dim(1)
840 .map_err(|e| FerrumError::model(format!("dim: {e}")))?;
841 let tts_pad_embed = embed_text_ids(&[tts_pad])?;
842
843 let mut hidden = self.talker.forward_step(&prefill_embeds)?;
845 let mut current_logits = self.talker.logits(
846 &hidden
847 .narrow(1, hidden.dim(1).unwrap() - 1, 1)
848 .map_err(|e| FerrumError::model(format!("narrow: {e}")))?,
849 )?;
850
851 let suppress_start = self.config.vocab_size.saturating_sub(1024);
853 let suppress_end = self.config.vocab_size;
854 let mut generated_tokens: Vec<u32> = Vec::new();
855 let mut frame_buffer: Vec<Vec<u32>> = Vec::new();
856 let mut audio_chunks: Vec<Vec<f32>> = Vec::new();
857
858 for step in 0..MAX_CODEC_TOKENS {
859 let mut logits_vec = logits_to_vec(¤t_logits)?;
860 for i in suppress_start..suppress_end.min(logits_vec.len()) {
861 if i as u32 != codec_eos {
862 logits_vec[i] = f32::NEG_INFINITY;
863 }
864 }
865 for &prev_tok in &generated_tokens {
866 let idx = prev_tok as usize;
867 if idx < logits_vec.len() {
868 if logits_vec[idx] > 0.0 {
869 logits_vec[idx] /= REPETITION_PENALTY;
870 } else {
871 logits_vec[idx] *= REPETITION_PENALTY;
872 }
873 }
874 }
875 let next_token =
876 sample_token(&logits_vec, tts_temperature(), TOP_K, REPETITION_PENALTY);
877 generated_tokens.push(next_token);
878
879 if next_token == codec_eos {
880 info!("TTS streaming: EOS at step {}", step);
881 break;
882 }
883
884 let last_hidden = hidden
885 .narrow(1, hidden.dim(1).unwrap() - 1, 1)
886 .map_err(|e| FerrumError::model(format!("narrow: {e}")))?;
887 let token_tensor = Tensor::new(&[next_token], &device)
888 .and_then(|t| t.unsqueeze(0))
889 .map_err(|e| FerrumError::model(format!("tok: {e}")))?;
890 let first_codec_embed = self.talker.embed_codec(&token_tensor)?;
891
892 let extra_codes = self.sub_talker.predict(
893 &last_hidden,
894 &first_codec_embed,
895 st_temperature(),
896 TOP_K,
897 )?;
898
899 let mut frame = vec![next_token];
900 frame.extend_from_slice(&extra_codes);
901 frame_buffer.push(frame);
902
903 if frame_buffer.len() >= chunk_frames {
905 let chunk_audio = self.decode_frames(&frame_buffer, &device)?;
906 on_chunk(audio_chunks.len(), &chunk_audio);
907 audio_chunks.push(chunk_audio);
908 frame_buffer.clear();
909 }
910
911 let mut combined_embed = first_codec_embed.clone();
913 for (i, &code) in extra_codes.iter().enumerate() {
914 let code_t = Tensor::new(&[code], &device)
915 .and_then(|t| t.unsqueeze(0))
916 .map_err(|e| FerrumError::model(format!("code_t: {e}")))?;
917 let sub_embed = code_t
918 .apply(&self.sub_talker.codec_embeddings[i])
919 .map_err(|e| FerrumError::model(format!("sub: {e}")))?;
920 combined_embed = (combined_embed + sub_embed)
921 .map_err(|e| FerrumError::model(format!("add: {e}")))?;
922 }
923 if step < trailing_text_len {
924 let trail = trailing_text_embeds
925 .narrow(1, step, 1)
926 .map_err(|e| FerrumError::model(format!("trail: {e}")))?;
927 combined_embed = (combined_embed + trail)
928 .map_err(|e| FerrumError::model(format!("add trail: {e}")))?;
929 } else {
930 combined_embed = (combined_embed + &tts_pad_embed)
931 .map_err(|e| FerrumError::model(format!("pad: {e}")))?;
932 }
933
934 hidden = self.talker.forward_step(&combined_embed)?;
935 current_logits = self.talker.logits(&hidden)?;
936 }
937
938 if !frame_buffer.is_empty() {
940 let chunk_audio = self.decode_frames(&frame_buffer, &device)?;
941 on_chunk(audio_chunks.len(), &chunk_audio);
942 audio_chunks.push(chunk_audio);
943 }
944
945 Ok(audio_chunks)
946 }
947
948 pub fn sample_rate(&self) -> usize {
950 SAMPLE_RATE
951 }
952
953 pub fn config(&self) -> &TalkerConfig {
954 &self.config
955 }
956
957 pub fn synthesize_voice_clone(
965 &mut self,
966 text: &str,
967 language: &str,
968 ref_audio_path: &str,
969 ref_text: &str,
970 ) -> Result<Vec<f32>> {
971 let device = self.talker.device().clone();
972
973 let ref_pcm = if let Some(path) = tts_runtime_env().ref_pcm.as_deref() {
975 let data = std::fs::read(&path)
976 .map_err(|e| FerrumError::model(format!("read ref pcm: {e}")))?;
977 let pcm: Vec<f32> = data
978 .chunks(4)
979 .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
980 .collect();
981 info!("Loaded ref PCM override: {} samples", pcm.len());
982 pcm
983 } else {
984 crate::audio_processor::load_audio_at_rate(ref_audio_path, 24000)?
985 };
986 info!(
987 "TTS voice clone: loaded ref audio {} samples ({:.2}s)",
988 ref_pcm.len(),
989 ref_pcm.len() as f64 / 24000.0
990 );
991
992 let t0 = std::time::Instant::now();
993 let speaker_encoder = self
995 .speaker_encoder
996 .as_ref()
997 .ok_or_else(|| FerrumError::model("speaker encoder not loaded"))?;
998 let mel = mel_spectrogram_speaker_encoder(&ref_pcm);
999 let n_mel_frames = mel.len() / 128;
1000 let mel_tensor = Tensor::from_vec(mel, (1, n_mel_frames, 128), &device)
1003 .map_err(|e| FerrumError::model(format!("mel tensor: {e}")))?;
1004 let spk_embed = speaker_encoder.forward(&mel_tensor)?;
1005 info!(
1006 "Step 2 (speaker embed): {:.1}ms",
1007 t0.elapsed().as_secs_f64() * 1000.0
1008 );
1009 let spk_embed = spk_embed
1011 .unsqueeze(0)
1012 .map_err(|e| FerrumError::model(format!("spk unsqueeze(0): {e}")))?
1013 .unsqueeze(0)
1014 .map_err(|e| FerrumError::model(format!("spk unsqueeze(0) 2: {e}")))?;
1015
1016 let t1 = std::time::Instant::now();
1017 let speech_enc = self
1019 .speech_tokenizer_encoder
1020 .as_ref()
1021 .ok_or_else(|| FerrumError::model("speech tokenizer encoder not loaded"))?;
1022 let ref_codes = if let Some(path) = tts_runtime_env().ref_codes.as_deref() {
1024 let data = std::fs::read(&path)
1025 .map_err(|e| FerrumError::model(format!("read ref codes: {e}")))?;
1026 let u32s: Vec<u32> = data
1027 .chunks(4)
1028 .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
1029 .collect();
1030 let ncb = self.config.num_code_groups;
1031 let nframes = u32s.len() / ncb;
1032 info!(
1033 "Loaded pre-computed ref codes: {} frames from {}",
1034 nframes, path
1035 );
1036 u32s.chunks(ncb).map(|c| c.to_vec()).collect()
1037 } else {
1038 let codes = speech_enc.encode(&ref_pcm)?;
1039 info!(
1040 "Step 3 (speech tokenizer): {:.1}ms",
1041 t1.elapsed().as_secs_f64() * 1000.0
1042 );
1043 codes
1044 };
1045 let ref_frames = ref_codes.len();
1046 info!(
1047 "TTS voice clone: ref_frames={}, spk_embed loaded",
1048 ref_frames
1049 );
1050 for i in 0..ref_frames.min(5) {
1052 info!(" rust codec frame {}: {:?}", i, &ref_codes[i]);
1053 }
1054
1055 info!(
1056 "Step 3 (speech tokenizer): {:.1}ms",
1057 t1.elapsed().as_secs_f64() * 1000.0
1058 );
1059 let t2 = std::time::Instant::now();
1060 let chat_text = format!("<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n");
1062 let encoding = self
1063 .text_tokenizer
1064 .encode(chat_text.as_str(), false)
1065 .map_err(|e| FerrumError::model(format!("tokenize: {e}")))?;
1066 let input_ids: Vec<u32> = encoding.get_ids().to_vec();
1067 let role_ids = &input_ids[..3];
1069 let text_content_ids = &input_ids[3..input_ids.len().saturating_sub(5)];
1070
1071 let ref_chat_text = format!("<|im_start|>assistant\n{ref_text}<|im_end|>\n");
1073 let ref_encoding = self
1074 .text_tokenizer
1075 .encode(ref_chat_text.as_str(), false)
1076 .map_err(|e| FerrumError::model(format!("tokenize ref: {e}")))?;
1077 let ref_ids: Vec<u32> = ref_encoding.get_ids().to_vec();
1078 let ref_text_ids = &ref_ids[3..ref_ids.len().saturating_sub(2)];
1080
1081 self.talker.reset();
1083
1084 let tts_bos = self.config.tts_bos_token_id;
1085 let tts_eos = self.config.tts_eos_token_id;
1086 let tts_pad = self.config.tts_pad_token_id;
1087 let codec_bos = self.config.codec_bos_id;
1088 let codec_eos = self.config.codec_eos_token_id;
1089 let codec_pad = self.config.codec_pad_id;
1090
1091 let embed_codec_ids = |ids: &[u32]| -> Result<Tensor> {
1093 let t = Tensor::new(ids, &device)
1094 .map_err(|e| FerrumError::model(format!("codec tensor: {e}")))?
1095 .unsqueeze(0)
1096 .map_err(|e| FerrumError::model(format!("codec unsqueeze: {e}")))?;
1097 self.talker.embed_codec(&t)
1098 };
1099 let embed_text_ids = |ids: &[u32]| -> Result<Tensor> {
1100 let t = Tensor::new(ids, &device)
1101 .map_err(|e| FerrumError::model(format!("text tensor: {e}")))?
1102 .unsqueeze(0)
1103 .map_err(|e| FerrumError::model(format!("text unsqueeze: {e}")))?;
1104 self.talker.embed_text(&t)
1105 };
1106
1107 let tts_special = embed_text_ids(&[tts_bos, tts_eos, tts_pad])?;
1109 let tts_bos_embed = tts_special
1110 .narrow(1, 0, 1)
1111 .map_err(|e| FerrumError::model(format!("tts_bos narrow: {e}")))?;
1112 let tts_eos_embed = tts_special
1113 .narrow(1, 1, 1)
1114 .map_err(|e| FerrumError::model(format!("tts_eos narrow: {e}")))?;
1115 let tts_pad_embed = tts_special
1116 .narrow(1, 2, 1)
1117 .map_err(|e| FerrumError::model(format!("tts_pad narrow: {e}")))?;
1118
1119 let resolved_lang = if language.eq_ignore_ascii_case("auto") {
1121 "chinese"
1122 } else {
1123 language
1124 };
1125 let language_id = self
1126 .config
1127 .codec_language_id
1128 .get(&resolved_lang.to_lowercase());
1129
1130 let codec_prefix_ids = if let Some(&lang_id) = language_id {
1132 vec![
1133 self.config.codec_think_id,
1134 self.config.codec_think_bos_id,
1135 lang_id,
1136 self.config.codec_think_eos_id,
1137 ]
1138 } else {
1139 vec![
1140 self.config.codec_nothink_id,
1141 self.config.codec_think_bos_id,
1142 self.config.codec_think_eos_id,
1143 ]
1144 };
1145 let codec_prefix_embed = embed_codec_ids(&codec_prefix_ids)?;
1146
1147 let codec_suffix_embed = embed_codec_ids(&[codec_pad, codec_bos])?;
1149
1150 let codec_input = Tensor::cat(&[&codec_prefix_embed, &spk_embed, &codec_suffix_embed], 1)
1152 .map_err(|e| FerrumError::model(format!("codec_input cat: {e}")))?;
1153 let codec_len = codec_input
1154 .dim(1)
1155 .map_err(|e| FerrumError::model(format!("codec_len dim: {e}")))?;
1156
1157 let role_embed = embed_text_ids(role_ids)?;
1159
1160 let n_pads = codec_len - 2;
1162 let mut text_prefix_parts = Vec::new();
1163 for _ in 0..n_pads {
1164 text_prefix_parts.push(tts_pad_embed.clone());
1165 }
1166 text_prefix_parts.push(tts_bos_embed.clone());
1167 let text_prefix_refs: Vec<&Tensor> = text_prefix_parts.iter().collect();
1168 let text_prefix = Tensor::cat(&text_prefix_refs, 1)
1169 .map_err(|e| FerrumError::model(format!("text_prefix cat: {e}")))?;
1170 let codec_prefix_part = codec_input
1171 .narrow(1, 0, codec_len - 1)
1172 .map_err(|e| FerrumError::model(format!("codec prefix narrow: {e}")))?;
1173 let text_codec_prefix = (&text_prefix + &codec_prefix_part)
1174 .map_err(|e| FerrumError::model(format!("text+codec prefix sum: {e}")))?;
1175
1176 let prefill_embed = Tensor::cat(&[&role_embed, &text_codec_prefix], 1)
1178 .map_err(|e| FerrumError::model(format!("prefill cat: {e}")))?;
1179
1180 let t3 = std::time::Instant::now();
1181
1182 let all_text_ids: Vec<u32> = ref_text_ids
1185 .iter()
1186 .chain(text_content_ids.iter())
1187 .copied()
1188 .collect();
1189 let text_embed = embed_text_ids(&all_text_ids)?;
1190 let text_embed_with_eos = Tensor::cat(&[&text_embed, &tts_eos_embed], 1)
1191 .map_err(|e| FerrumError::model(format!("text+eos cat: {e}")))?;
1192 let text_len = text_embed_with_eos
1193 .dim(1)
1194 .map_err(|e| FerrumError::model(format!("text_len dim: {e}")))?;
1195
1196 let t_codec_start = std::time::Instant::now();
1199 let ncg = self.config.num_code_groups;
1200 let first_codes: Vec<u32> = ref_codes.iter().map(|f| f[0]).collect();
1201 let codec_frames_cat = {
1202 let mut sum = embed_codec_ids(&first_codes)?;
1204 for cb in 0..(ncg - 1) {
1206 let codes: Vec<u32> = ref_codes.iter().map(|f| f[cb + 1]).collect();
1207 let codes_t = Tensor::new(codes.as_slice(), &device)
1208 .map_err(|e| FerrumError::model(format!("batch codes: {e}")))?
1209 .unsqueeze(0)
1210 .map_err(|e| FerrumError::model(format!("batch unsqueeze: {e}")))?;
1211 let sub_embed = codes_t
1212 .apply(&self.sub_talker.codec_embeddings[cb])
1213 .map_err(|e| FerrumError::model(format!("batch sub_embed: {e}")))?;
1214 sum =
1215 (sum + sub_embed).map_err(|e| FerrumError::model(format!("batch add: {e}")))?;
1216 }
1217 sum
1218 };
1219 info!(
1220 "Codec embedding: {:.1}ms ({} frames × {} codebooks)",
1221 t_codec_start.elapsed().as_secs_f64() * 1000.0,
1222 ref_codes.len(),
1223 ncg
1224 );
1225
1226 let t_merge = std::time::Instant::now();
1228 let codec_bos_for_icl = embed_codec_ids(&[codec_bos])?;
1229 let icl_codec = Tensor::cat(&[&codec_bos_for_icl, &codec_frames_cat], 1)
1230 .map_err(|e| FerrumError::model(format!("icl_codec cat: {e}")))?;
1231 let codec_icl_len = icl_codec
1232 .dim(1)
1233 .map_err(|e| FerrumError::model(format!("codec_icl_len dim: {e}")))?;
1234
1235 let icl_trailing: Tensor;
1237 let icl_embed: Tensor;
1238 if text_len > codec_icl_len {
1239 let text_part = text_embed_with_eos
1240 .narrow(1, 0, codec_icl_len)
1241 .map_err(|e| FerrumError::model(format!("text_part narrow: {e}")))?;
1242 icl_embed = (&text_part + &icl_codec)
1243 .map_err(|e| FerrumError::model(format!("text+codec sum: {e}")))?;
1244 icl_trailing = text_embed_with_eos
1245 .narrow(1, codec_icl_len, text_len - codec_icl_len)
1246 .map_err(|e| FerrumError::model(format!("trailing narrow: {e}")))?;
1247 } else {
1248 let n_pad = codec_icl_len - text_len;
1249 let text_padded = if n_pad > 0 {
1250 let pad_block = tts_pad_embed
1251 .expand((1, n_pad, self.config.hidden_size))
1252 .map_err(|e| FerrumError::model(format!("pad expand: {e}")))?;
1253 Tensor::cat(&[&text_embed_with_eos, &pad_block], 1)
1254 .map_err(|e| FerrumError::model(format!("text_padded cat: {e}")))?
1255 } else {
1256 text_embed_with_eos.clone()
1257 };
1258 icl_embed = (&text_padded + &icl_codec)
1259 .map_err(|e| FerrumError::model(format!("padded+codec sum: {e}")))?;
1260 icl_trailing = tts_pad_embed.clone();
1261 }
1262 let trailing_text_len = icl_trailing
1263 .dim(1)
1264 .map_err(|e| FerrumError::model(format!("trailing dim: {e}")))?;
1265
1266 let _prefill_out = self.talker.forward_step(&prefill_embed)?;
1269 let t_icl = std::time::Instant::now();
1270 let icl_hidden = self.talker.forward_step(&icl_embed)?;
1271 let icl_len = icl_hidden
1272 .dim(1)
1273 .map_err(|e| FerrumError::model(format!("icl_hidden dim: {e}")))?;
1274 info!(
1275 "ICL block: {:.1}ms ({} tokens), trailing={}",
1276 t_icl.elapsed().as_secs_f64() * 1000.0,
1277 icl_len,
1278 trailing_text_len
1279 );
1280
1281 let mut hidden = icl_hidden;
1283 let hidden_len = hidden
1284 .dim(1)
1285 .map_err(|e| FerrumError::model(format!("hidden dim: {e}")))?;
1286 let last_hidden = hidden
1287 .narrow(1, hidden_len - 1, 1)
1288 .map_err(|e| FerrumError::model(format!("narrow last: {e}")))?;
1289 if let Ok(v) = last_hidden.flatten_all().and_then(|t| t.to_vec1::<f32>()) {}
1290 let current_logits = self.talker.logits(&last_hidden)?;
1291 {}
1292
1293 let mut all_codec_tokens: Vec<Vec<u32>> = Vec::new();
1295 let mut current_logits = current_logits;
1296
1297 let suppress_start = self.config.vocab_size.saturating_sub(1024);
1299 let suppress_end = self.config.vocab_size;
1300
1301 const ICL_REPETITION_PENALTY: f32 = 1.5;
1304 const ICL_FRAMES_PER_TOKEN: usize = 6;
1305 const ICL_MIN_FRAMES: usize = 75;
1306 let max_icl_tokens = ICL_MIN_FRAMES.max(text_content_ids.len() * ICL_FRAMES_PER_TOKEN);
1307 let mut generated_tokens: Vec<u32> = Vec::new();
1308
1309 for step in 0..max_icl_tokens {
1310 let mut logits_vec = logits_to_vec(¤t_logits)?;
1311 for i in suppress_start..suppress_end.min(logits_vec.len()) {
1313 if i as u32 != codec_eos {
1314 logits_vec[i] = f32::NEG_INFINITY;
1315 }
1316 }
1317 let min_frames = tts_runtime_env()
1321 .min_frames
1322 .unwrap_or_else(|| text_content_ids.len() * ICL_FRAMES_PER_TOKEN);
1323 if step < min_frames {
1324 if let Some(v) = logits_vec.get_mut(codec_eos as usize) {
1325 *v = f32::NEG_INFINITY;
1326 }
1327 }
1328 for &prev_tok in &generated_tokens {
1330 let idx = prev_tok as usize;
1331 if idx < logits_vec.len() {
1332 if logits_vec[idx] > 0.0 {
1333 logits_vec[idx] /= ICL_REPETITION_PENALTY;
1334 } else {
1335 logits_vec[idx] *= ICL_REPETITION_PENALTY;
1336 }
1337 }
1338 }
1339 let next_token = sample_token(
1340 &logits_vec,
1341 tts_temperature(),
1342 TOP_K,
1343 ICL_REPETITION_PENALTY,
1344 );
1345
1346 generated_tokens.push(next_token);
1347
1348 if next_token == codec_eos {
1349 info!("TTS voice clone: codec EOS at step {}", step);
1350 break;
1351 }
1352
1353 if generated_tokens.len() >= 6 {
1355 let n = generated_tokens.len();
1356 let mut is_repeat = false;
1357 for pat_len in 1..=4 {
1358 if n >= pat_len * 3 {
1359 let a = &generated_tokens[n - pat_len * 3..n - pat_len * 2];
1360 let b = &generated_tokens[n - pat_len * 2..n - pat_len];
1361 let c = &generated_tokens[n - pat_len..n];
1362 if a == b && b == c {
1363 is_repeat = true;
1364 break;
1365 }
1366 }
1367 }
1368 if is_repeat {
1369 info!(
1370 "TTS voice clone: repetition detected at step {}, stopping",
1371 step
1372 );
1373 break;
1374 }
1375 }
1376
1377 let cur_hidden_len = hidden
1378 .dim(1)
1379 .map_err(|e| FerrumError::model(format!("hidden dim: {e}")))?;
1380 let last_hidden = hidden
1381 .narrow(1, cur_hidden_len - 1, 1)
1382 .map_err(|e| FerrumError::model(format!("last_hidden: {e}")))?;
1383
1384 let token_tensor = Tensor::new(&[next_token], &device)
1385 .map_err(|e| FerrumError::model(format!("token tensor: {e}")))?
1386 .unsqueeze(0)
1387 .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?;
1388 let first_codec_embed = self.talker.embed_codec(&token_tensor)?;
1389
1390 let extra_codes = self.sub_talker.predict(
1391 &last_hidden,
1392 &first_codec_embed,
1393 st_temperature(),
1394 TOP_K,
1395 )?;
1396
1397 let mut frame_codes = vec![next_token];
1398 frame_codes.extend_from_slice(&extra_codes);
1399 all_codec_tokens.push(frame_codes);
1400
1401 let mut combined_embed = first_codec_embed.clone();
1403 for (i, &code) in extra_codes.iter().enumerate() {
1404 let code_t = Tensor::new(&[code], &device)
1405 .and_then(|t| t.unsqueeze(0))
1406 .map_err(|e| FerrumError::model(format!("code_t: {e}")))?;
1407 let sub_embed = code_t
1408 .apply(&self.sub_talker.codec_embeddings[i])
1409 .map_err(|e| FerrumError::model(format!("sub_embed: {e}")))?;
1410 combined_embed = (combined_embed + sub_embed)
1411 .map_err(|e| FerrumError::model(format!("add embed: {e}")))?;
1412 }
1413
1414 if step < trailing_text_len {
1416 let trail = icl_trailing
1417 .narrow(1, step, 1)
1418 .map_err(|e| FerrumError::model(format!("trailing narrow: {e}")))?;
1419 combined_embed = (combined_embed + trail)
1420 .map_err(|e| FerrumError::model(format!("add trailing: {e}")))?;
1421 } else {
1422 combined_embed = (combined_embed + &tts_pad_embed)
1423 .map_err(|e| FerrumError::model(format!("add tts_pad: {e}")))?;
1424 }
1425
1426 hidden = self.talker.forward_step(&combined_embed)?;
1427 current_logits = self.talker.logits(&hidden)?;
1428 }
1429
1430 if all_codec_tokens.is_empty() {
1431 return Err(FerrumError::model("no codec tokens generated"));
1432 }
1433 info!(
1434 "TTS voice clone: generated {} codec frames",
1435 all_codec_tokens.len()
1436 );
1437
1438 let mut all_codes_with_ref = ref_codes.clone();
1440 all_codes_with_ref.extend_from_slice(&all_codec_tokens);
1441
1442 let num_frames = all_codes_with_ref.len();
1443 let num_groups = self.config.num_code_groups;
1444
1445 let mut flat_codes: Vec<u32> = vec![0; num_groups * num_frames];
1447 for (t, frame) in all_codes_with_ref.iter().enumerate() {
1448 for (g, &code) in frame.iter().take(num_groups).enumerate() {
1449 flat_codes[g * num_frames + t] = code;
1450 }
1451 }
1452
1453 let codebook_size = 2048u32;
1455 for code in &mut flat_codes {
1456 if *code >= codebook_size {
1457 *code = 0;
1458 }
1459 }
1460
1461 let codes_tensor = Tensor::new(&flat_codes[..], &device)
1462 .map_err(|e| FerrumError::model(format!("codes tensor: {e}")))?
1463 .reshape((1, num_groups, num_frames))
1464 .map_err(|e| FerrumError::model(format!("reshape codes: {e}")))?;
1465
1466 let waveform = self.vocoder.decode(&codes_tensor)?;
1467
1468 let samples: Vec<f32> = waveform
1469 .squeeze(0)
1470 .map_err(|e| FerrumError::model(format!("squeeze batch: {e}")))?
1471 .squeeze(0)
1472 .map_err(|e| FerrumError::model(format!("squeeze channel: {e}")))?
1473 .to_vec1()
1474 .map_err(|e| FerrumError::model(format!("to_vec1: {e}")))?;
1475
1476 let ref_ratio = ref_frames as f64 / num_frames as f64;
1478 let cut = (ref_ratio * samples.len() as f64) as usize;
1479 let output_samples = samples[cut..].to_vec();
1480
1481 info!(
1482 "TTS voice clone: waveform {} samples ({:.2}s), trimmed ref {} samples",
1483 output_samples.len(),
1484 output_samples.len() as f64 / SAMPLE_RATE as f64,
1485 cut,
1486 );
1487
1488 Ok(output_samples)
1489 }
1490}
1491
1492fn find_safetensor_files(dir: &std::path::Path, prefix: &str) -> Result<Vec<std::path::PathBuf>> {
1496 let single = dir.join(format!("{prefix}.safetensors"));
1498 if single.exists() {
1499 return Ok(vec![single]);
1500 }
1501
1502 let mut files: Vec<std::path::PathBuf> = Vec::new();
1504 if let Ok(entries) = std::fs::read_dir(dir) {
1505 for entry in entries.flatten() {
1506 let path = entry.path();
1507 if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
1508 if name.starts_with(prefix)
1509 && name.ends_with(".safetensors")
1510 && name != format!("{prefix}.safetensors")
1511 {
1512 files.push(path);
1513 }
1514 }
1515 }
1516 }
1517 files.sort();
1518
1519 if files.is_empty() {
1520 Err(FerrumError::model(format!(
1521 "no safetensors files with prefix '{prefix}' in {}",
1522 dir.display()
1523 )))
1524 } else {
1525 Ok(files)
1526 }
1527}
1528
1529fn load_bpe_tokenizer(dir: &std::path::Path) -> Result<tokenizers::Tokenizer> {
1531 let tokenizer_json = dir.join("tokenizer.json");
1533 if tokenizer_json.exists() {
1534 return tokenizers::Tokenizer::from_file(&tokenizer_json)
1535 .map_err(|e| FerrumError::model(format!("load tokenizer.json: {e}")));
1536 }
1537
1538 let vocab_path = dir.join("vocab.json");
1540 let merges_path = dir.join("merges.txt");
1541
1542 if !vocab_path.exists() || !merges_path.exists() {
1543 return Err(FerrumError::model(
1544 "tokenizer.json not found, and vocab.json + merges.txt not found either",
1545 ));
1546 }
1547
1548 let vocab_data = std::fs::read_to_string(&vocab_path)
1549 .map_err(|e| FerrumError::model(format!("read vocab.json: {e}")))?;
1550 let vocab: HashMap<String, u32> = serde_json::from_str(&vocab_data)
1551 .map_err(|e| FerrumError::model(format!("parse vocab.json: {e}")))?;
1552
1553 let merges_data = std::fs::read_to_string(&merges_path)
1554 .map_err(|e| FerrumError::model(format!("read merges.txt: {e}")))?;
1555 let merges: Vec<(String, String)> = merges_data
1556 .lines()
1557 .skip(1) .filter(|line| !line.is_empty())
1559 .filter_map(|line| {
1560 let parts: Vec<&str> = line.splitn(2, ' ').collect();
1561 if parts.len() == 2 {
1562 Some((parts[0].to_string(), parts[1].to_string()))
1563 } else {
1564 None
1565 }
1566 })
1567 .collect();
1568
1569 let bpe = tokenizers::models::bpe::BPE::from_file(
1570 vocab_path.to_str().unwrap(),
1571 merges_path.to_str().unwrap(),
1572 )
1573 .build()
1574 .map_err(|e| FerrumError::model(format!("build BPE: {e}")))?;
1575
1576 let tokenizer = tokenizers::Tokenizer::new(bpe);
1577 Ok(tokenizer)
1578}
1579
1580fn logits_to_vec(logits: &Tensor) -> Result<Vec<f32>> {
1582 let logits = if logits.dims().len() == 3 {
1583 logits
1584 .squeeze(0)
1585 .map_err(|e| FerrumError::model(format!("squeeze: {e}")))?
1586 .squeeze(0)
1587 .map_err(|e| FerrumError::model(format!("squeeze: {e}")))?
1588 } else if logits.dims().len() == 2 {
1589 logits
1590 .squeeze(0)
1591 .map_err(|e| FerrumError::model(format!("squeeze: {e}")))?
1592 } else {
1593 logits.clone()
1594 };
1595
1596 logits
1597 .to_vec1()
1598 .map_err(|e| FerrumError::model(format!("logits to_vec1: {e}")))
1599}
1600
1601pub fn sample_token(
1609 logits: &[f32],
1610 temperature: f32,
1611 top_k: usize,
1612 _repetition_penalty: f32,
1613) -> u32 {
1614 if temperature < 0.01 {
1615 return argmax(logits);
1616 }
1617
1618 let vocab = logits.len();
1619
1620 let scaled: Vec<f32> = logits.iter().map(|&x| x / temperature).collect();
1622
1623 let mut filtered = scaled.clone();
1625 if top_k > 0 && top_k < vocab {
1626 let mut sorted = scaled.clone();
1627 sorted.sort_unstable_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
1628 let threshold = sorted[top_k - 1];
1629 for v in &mut filtered {
1630 if *v < threshold {
1631 *v = f32::NEG_INFINITY;
1632 }
1633 }
1634 }
1635
1636 const TOP_P: f32 = 0.9;
1638 {
1639 let mut indices: Vec<usize> = (0..vocab).collect();
1640 indices.sort_unstable_by(|&a, &b| {
1641 filtered[b]
1642 .partial_cmp(&filtered[a])
1643 .unwrap_or(std::cmp::Ordering::Equal)
1644 });
1645
1646 let max_val = filtered[indices[0]];
1648 let exp_sorted: Vec<f32> = indices
1649 .iter()
1650 .map(|&i| (filtered[i] - max_val).exp())
1651 .collect();
1652 let sum: f32 = exp_sorted.iter().sum();
1653 let probs_sorted: Vec<f32> = exp_sorted.iter().map(|e| e / sum).collect();
1654
1655 let mut cumsum = 0.0f32;
1657 let mut cutoff_idx = vocab;
1658 for (i, &p) in probs_sorted.iter().enumerate() {
1659 cumsum += p;
1660 if cumsum > TOP_P {
1661 cutoff_idx = i + 1;
1662 break;
1663 }
1664 }
1665
1666 for &idx in &indices[cutoff_idx..] {
1668 filtered[idx] = f32::NEG_INFINITY;
1669 }
1670 }
1671
1672 let max_val = filtered.iter().copied().fold(f32::NEG_INFINITY, f32::max);
1674 let exps: Vec<f32> = filtered.iter().map(|&v| (v - max_val).exp()).collect();
1675 let sum: f32 = exps.iter().sum();
1676 let probs: Vec<f32> = exps.iter().map(|e| e / sum).collect();
1677
1678 let r = rand_f32();
1680 let mut cumulative = 0.0f32;
1681 for (i, &p) in probs.iter().enumerate() {
1682 cumulative += p;
1683 if cumulative >= r {
1684 return i as u32;
1685 }
1686 }
1687 argmax(&probs)
1689}
1690
/// Returns the index of the largest value in `v`, or `0` when `v` is empty.
///
/// Ties resolve to the last of the equal entries. An entry that cannot be
/// ordered against the current best (NaN on either side) also replaces it.
fn argmax(v: &[f32]) -> u32 {
    let mut best: Option<(usize, f32)> = None;
    for (idx, &val) in v.iter().enumerate() {
        best = match best {
            // Keep the incumbent only when it is strictly greater.
            Some((_, top)) if top > val => best,
            _ => Some((idx, val)),
        };
    }
    best.map_or(0, |(idx, _)| idx as u32)
}
1698
/// Returns a pseudo-random `f32` spread uniformly over the open interval
/// `(0, 1)`.
///
/// The generator is seeded from the wall clock plus a process-wide call
/// counter and is not cryptographically secure. The result is never exactly
/// `0.0` or `1.0`, so a cumulative-probability comparison cannot select a
/// zero-probability entry or run past the end of a distribution.
fn rand_f32() -> f32 {
    use std::sync::atomic::{AtomicU64, Ordering};
    static COUNTER: AtomicU64 = AtomicU64::new(0);

    // Truncating the nanosecond count to 64 bits is intentional: only the
    // fast-changing low bits matter for seeding.
    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos() as u64;
    let count = COUNTER.fetch_add(1, Ordering::Relaxed);

    // SplitMix64: step the state by the golden-ratio constant per call, then
    // run the finalizer so every output bit depends on every input bit. A
    // plain multiply-and-add of these small inputs never reaches the upper
    // range of u64, which would confine the result to a sliver near zero.
    let mut z = nanos.wrapping_add(count.wrapping_mul(0x9E37_79B9_7F4A_7C15));
    z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    z ^= z >> 31;

    // Use the top 23 bits: `k + 0.5` for k < 2^23 is exactly representable in
    // f32, and dividing by 2^23 keeps the value strictly inside (0, 1).
    ((z >> 41) as f32 + 0.5) / 8_388_608.0
}
1716
1717#[derive(Clone, Debug)]
1720#[allow(dead_code)]
1721struct DummyTtsCache;
1722
1723impl ferrum_interfaces::KvCacheHandle for DummyTtsCache {
1724 fn block_table(&self) -> &ferrum_interfaces::BlockTable {
1725 static EMPTY: std::sync::OnceLock<ferrum_interfaces::BlockTable> =
1726 std::sync::OnceLock::new();
1727 EMPTY.get_or_init(|| ferrum_interfaces::BlockTable::new(16))
1728 }
1729 fn block_table_mut(&mut self) -> &mut ferrum_interfaces::BlockTable {
1730 unimplemented!()
1731 }
1732 fn as_any(&self) -> &dyn std::any::Any {
1733 self
1734 }
1735 fn device(&self) -> Device {
1736 Device::CPU
1737 }
1738 fn num_layers(&self) -> usize {
1739 0
1740 }
1741 fn num_heads(&self) -> usize {
1742 0
1743 }
1744 fn head_dim(&self) -> usize {
1745 0
1746 }
1747 fn key_cache(&self, _: usize) -> Result<Option<TensorRef>> {
1748 Ok(None)
1749 }
1750 fn value_cache(&self, _: usize) -> Result<Option<TensorRef>> {
1751 Ok(None)
1752 }
1753 fn clone_handle(&self) -> Result<Arc<dyn ferrum_interfaces::KvCacheHandle>> {
1754 Ok(Arc::new(self.clone()))
1755 }
1756 fn stats(&self) -> ferrum_interfaces::CacheHandleStats {
1757 ferrum_interfaces::CacheHandleStats {
1758 memory_bytes: 0,
1759 blocks_allocated: 0,
1760 tokens_stored: 0,
1761 utilization: 0.0,
1762 last_access: std::time::Instant::now(),
1763 }
1764 }
1765 fn is_valid(&self) -> bool {
1766 true
1767 }
1768 fn cache_id(&self) -> String {
1769 "tts_dummy".to_string()
1770 }
1771}
1772
1773#[async_trait]
1774impl ModelExecutor for TtsModelExecutor {
1775 fn info(&self) -> &ModelInfo {
1776 &self.info
1777 }
1778
1779 async fn prefill(&self, _input: &PrefillInput) -> Result<PrefillOutput> {
1780 Err(FerrumError::model(
1781 "TTS uses synthesize(), not prefill/decode",
1782 ))
1783 }
1784
1785 async fn decode(&self, _input: &DecodeInput) -> Result<DecodeOutput> {
1786 Err(FerrumError::model(
1787 "TTS uses synthesize(), not prefill/decode",
1788 ))
1789 }
1790
1791 fn capabilities(&self) -> ExecutorCapabilities {
1792 ExecutorCapabilities {
1793 max_batch_size: 1,
1794 max_sequence_length: self.info.max_sequence_length,
1795 attention_mechanisms: vec![AttentionType::GroupedQuery],
1796 supports_dynamic_batching: false,
1797 supports_continuous_batching: false,
1798 supports_speculative_decoding: false,
1799 supports_tensor_parallelism: false,
1800 supports_pipeline_parallelism: false,
1801 supported_dtypes: vec![DataType::FP32, DataType::BF16],
1802 supported_devices: vec![self.info.device.clone()],
1803 memory_requirements: MemoryRequirements {
1804 parameter_memory: 0,
1805 activation_memory_per_token: 0,
1806 kv_cache_memory_per_token: 0,
1807 overhead_memory: 0,
1808 },
1809 }
1810 }
1811
1812 fn release_cache(&self, _: &str) {}
1813
1814 fn status(&self) -> ferrum_interfaces::model_executor::ExecutorStatus {
1815 common::default_executor_status()
1816 }
1817}
1818
#[cfg(test)]
mod tests {
    use super::*;

    /// Well-formed overrides must all be picked up from the supplied variables.
    #[test]
    fn tts_runtime_env_parses_overrides() {
        let overrides = [
            ("FERRUM_TTS_TEMP", "0.7"),
            ("FERRUM_ST_TEMP", "0.2"),
            ("FERRUM_REF_PCM", "/tmp/ref.pcm"),
            ("FERRUM_REF_CODES", "/tmp/ref.codes"),
            ("FERRUM_TTS_MIN_FRAMES", "128"),
        ];
        let parsed = TtsRuntimeEnv::from_env_vars(overrides);

        // Numeric settings.
        assert_eq!(parsed.tts_temperature, 0.7);
        assert_eq!(parsed.st_temperature(), 0.2);
        assert_eq!(parsed.min_frames, Some(128));

        // Path settings.
        assert_eq!(parsed.ref_pcm.as_deref(), Some("/tmp/ref.pcm"));
        assert_eq!(parsed.ref_codes.as_deref(), Some("/tmp/ref.codes"));
    }

    /// Unparseable values must fall back to the built-in defaults.
    #[test]
    fn tts_runtime_env_defaults_invalid_values() {
        let garbage = [
            ("FERRUM_TTS_TEMP", "invalid"),
            ("FERRUM_ST_TEMP", "invalid"),
            ("FERRUM_TTS_MIN_FRAMES", "invalid"),
        ];
        let parsed = TtsRuntimeEnv::from_env_vars(garbage);

        assert_eq!(parsed.tts_temperature, TEMPERATURE);
        assert_eq!(parsed.st_temperature(), TEMPERATURE);
        assert_eq!(parsed.min_frames, None);
    }
}