1use std::collections::HashMap;
7use std::sync::Arc;
8
9use async_trait::async_trait;
10use candle_core::{DType, Device as CandleDevice, Tensor};
11use candle_nn::VarBuilder;
12use ferrum_interfaces::{
13 model_executor::{
14 AttentionType, DecodeInput, DecodeOutput, ExecutorCapabilities, MemoryRequirements,
15 PrefillInput, PrefillOutput,
16 },
17 ModelExecutor, TensorRef,
18};
19use ferrum_types::{DataType, Device, FerrumError, ModelInfo, ModelType, Result};
20use tracing::info;
21
22use super::common;
23use crate::architectures::qwen3_tts::{Qwen3TTSTalker, SubTalker, TalkerConfig};
24use crate::architectures::qwen3_tts_backbone::TalkerBackboneBackend;
25use crate::architectures::qwen3_tts_vocoder::{Qwen3TTSVocoder, VocoderConfig};
26use crate::architectures::speaker_encoder::{mel_spectrogram_speaker_encoder, SpeakerEncoder};
27use crate::architectures::speech_tokenizer_encoder::SpeechTokenizerEncoder;
28use ferrum_quantization::NativeSafetensorsLoader;
29
/// Replaces the candle forward path of the Talker and SubTalker with
/// ferrum-native CUDA backbones built from the checkpoint in `model_dir`.
///
/// The Talker override is installed before the SubTalker backbone is built,
/// so an error from the SubTalker step leaves the Talker override in place.
///
/// # Errors
/// Propagates failures from opening the safetensors loader or from building
/// either backbone.
#[cfg(feature = "cuda")]
fn install_cuda_backend_overrides(
    cfg: &TalkerConfig,
    model_dir: &std::path::Path,
    talker: &mut Qwen3TTSTalker,
    sub_talker: &mut SubTalker,
) -> Result<()> {
    use ferrum_kernels::backend::cuda::CudaBackend;

    let weights = NativeSafetensorsLoader::<CudaBackend>::open(model_dir)?;

    // Talker backbone first, then the code-predictor (SubTalker) backbone.
    talker.set_backend_override(Box::new(TalkerBackboneBackend::<CudaBackend>::new(
        cfg, &weights,
    )?));
    sub_talker.set_backend_override(Box::new(
        TalkerBackboneBackend::<CudaBackend>::new_code_predictor(cfg, &weights)?,
    ));

    Ok(())
}
48
/// Sample rate (Hz) of the PCM returned by the synthesis methods.
const SAMPLE_RATE: usize = 24_000;
/// Upper bound on generated codec frames per request; generation stops earlier at codec EOS.
const MAX_CODEC_TOKENS: usize = 2_000;

/// Default Talker sampling temperature (overridable via `FERRUM_TTS_TEMP`).
const TEMPERATURE: f32 = 0.9;
/// Top-k cutoff shared by Talker and SubTalker sampling.
const TOP_K: usize = 50;
/// Repetition penalty applied to the Talker's first-codebook logits.
const REPETITION_PENALTY: f32 = 1.05;
59
60fn tts_temperature() -> f32 {
61 std::env::var("FERRUM_TTS_TEMP")
62 .ok()
63 .and_then(|v| v.parse().ok())
64 .unwrap_or(TEMPERATURE)
65}
66
67fn st_temperature() -> f32 {
68 std::env::var("FERRUM_ST_TEMP")
69 .ok()
70 .and_then(|v| v.parse().ok())
71 .unwrap_or(tts_temperature())
72}
73
74pub struct TtsModelExecutor {
76 talker: Qwen3TTSTalker,
77 sub_talker: SubTalker,
78 vocoder: Qwen3TTSVocoder,
79 text_tokenizer: tokenizers::Tokenizer,
80 config: TalkerConfig,
81 info: ModelInfo,
82 speaker_encoder: Option<SpeakerEncoder>,
83 speech_tokenizer_encoder: Option<SpeechTokenizerEncoder>,
84}
85
86impl TtsModelExecutor {
87 pub fn from_path(model_path: &str, device: CandleDevice, dtype: DType) -> Result<Self> {
93 let dir = std::path::Path::new(model_path);
94
95 let config_json: serde_json::Value = {
97 let config_path = dir.join("config.json");
98 let data = std::fs::read_to_string(&config_path)
99 .map_err(|e| FerrumError::model(format!("read config.json: {e}")))?;
100 serde_json::from_str(&data)
101 .map_err(|e| FerrumError::model(format!("parse config.json: {e}")))?
102 };
103 let config = TalkerConfig::from_json(&config_json)?;
104
105 let text_tokenizer = load_bpe_tokenizer(dir)?;
107
108 let talker_weights = find_safetensor_files(dir, "model")?;
110 let talker_vb = unsafe {
111 VarBuilder::from_mmaped_safetensors(&talker_weights, dtype, &device)
112 .map_err(|e| FerrumError::model(format!("load talker weights: {e}")))?
113 };
114 let mut talker = Qwen3TTSTalker::load(&config, talker_vb.clone(), device.clone())?;
115
116 let mut sub_talker = SubTalker::load(&config, talker_vb.clone(), device.clone())?;
118
119 let spk_enc_dim = config_json
121 .get("speaker_encoder_config")
122 .and_then(|c| c.get("enc_dim"))
123 .and_then(|v| v.as_u64())
124 .unwrap_or(1024) as usize;
125 let speaker_encoder =
126 SpeakerEncoder::load_with_dim(talker_vb.pp("speaker_encoder"), spk_enc_dim)
127 .map_err(|e| {
128 tracing::warn!("Speaker encoder not available: {e}");
129 e
130 })
131 .ok();
132
133 let vocoder_dir = dir.join("speech_tokenizer");
135 let vocoder_weights = find_safetensor_files(&vocoder_dir, "model")?;
136 let vocoder_vb = unsafe {
137 VarBuilder::from_mmaped_safetensors(&vocoder_weights, dtype, &device)
138 .map_err(|e| FerrumError::model(format!("load vocoder weights: {e}")))?
139 };
140 let vocoder_config = VocoderConfig::default();
141 let vocoder = Qwen3TTSVocoder::load(&vocoder_config, vocoder_vb.clone())?;
142
143 let speech_tokenizer_encoder = if vocoder_dir.join("config.json").exists() {
147 let cpu_vb = unsafe {
148 VarBuilder::from_mmaped_safetensors(&vocoder_weights, dtype, &CandleDevice::Cpu)
149 .map_err(|e| FerrumError::model(format!("load encoder cpu: {e}")))?
150 };
151 SpeechTokenizerEncoder::load(cpu_vb.pp("encoder"), CandleDevice::Cpu)
152 .map_err(|e| {
153 tracing::warn!("Speech tokenizer encoder not available: {e}");
154 e
155 })
156 .ok()
157 } else {
158 None
159 };
160
161 #[cfg(feature = "cuda")]
170 if matches!(&device, CandleDevice::Cuda(_)) {
171 match install_cuda_backend_overrides(&config, dir, &mut talker, &mut sub_talker) {
172 Ok(()) => {
173 tracing::info!(
174 "TtsModelExecutor: Backend<CudaBackend> installed for Talker + SubTalker"
175 );
176 }
177 Err(e) => {
178 tracing::warn!(
179 "TtsModelExecutor: Backend<CudaBackend> install failed ({e}); \
180 falling back to candle/fused path (CUDA voice-clone may produce garbage)"
181 );
182 }
183 }
184 }
185
186 let info = ModelInfo {
187 model_id: ferrum_types::ModelId(model_path.to_string()),
188 model_type: ModelType::Custom("qwen3-tts".to_string()),
189 hidden_size: config.hidden_size,
190 vocab_size: config.vocab_size,
191 num_layers: config.num_hidden_layers,
192 num_heads: config.num_attention_heads,
193 num_kv_heads: config.num_key_value_heads,
194 num_parameters: 0,
195 max_sequence_length: config.max_position_embeddings,
196 device: match &device {
197 CandleDevice::Cpu => Device::CPU,
198 CandleDevice::Cuda(_) => Device::CUDA(0),
199 #[cfg(any(target_os = "macos", target_os = "ios"))]
200 CandleDevice::Metal(_) => Device::Metal,
201 #[cfg(not(any(target_os = "macos", target_os = "ios")))]
202 CandleDevice::Metal(_) => Device::CPU,
203 },
204 dtype: match dtype {
205 DType::F32 => DataType::FP32,
206 DType::F16 => DataType::FP16,
207 DType::BF16 => DataType::BF16,
208 _ => DataType::FP32,
209 },
210 version: None,
211 license: None,
212 metadata: HashMap::new(),
213 };
214
215 info!(
216 "TtsModelExecutor: {} (hidden={}, layers={}, codec_groups={})",
217 model_path, config.hidden_size, config.num_hidden_layers, config.num_code_groups,
218 );
219
220 Ok(Self {
221 talker,
222 sub_talker,
223 vocoder,
224 text_tokenizer,
225 config,
226 info,
227 speaker_encoder,
228 speech_tokenizer_encoder,
229 })
230 }
231
232 pub fn synthesize(&mut self, text: &str, language: &str) -> Result<Vec<f32>> {
240 self.talker.reset();
241
242 let device = self.talker.device().clone();
243
244 let encoding = self
246 .text_tokenizer
247 .encode(text, false)
248 .map_err(|e| FerrumError::model(format!("tokenize: {e}")))?;
249 let content_ids: Vec<u32> = encoding.get_ids().to_vec();
250
251 if content_ids.is_empty() {
252 return Err(FerrumError::model("empty text after tokenization"));
253 }
254
255 info!("TTS: content tokens = {}", content_ids.len());
256
257 let codec_eos = self.config.codec_eos_token_id;
258 let tts_pad = self.config.tts_pad_token_id;
259 let tts_bos = self.config.tts_bos_token_id;
260 let tts_eos = self.config.tts_eos_token_id;
261
262 let embed_text_ids = |ids: &[u32]| -> Result<Tensor> {
264 let t = Tensor::new(ids, &device)
265 .and_then(|t| t.unsqueeze(0))
266 .map_err(|e| FerrumError::model(format!("text ids: {e}")))?;
267 self.talker.embed_text(&t)
268 };
269 let embed_codec_ids = |ids: &[u32]| -> Result<Tensor> {
270 let t = Tensor::new(ids, &device)
271 .and_then(|t| t.unsqueeze(0))
272 .map_err(|e| FerrumError::model(format!("codec ids: {e}")))?;
273 self.talker.embed_codec(&t)
274 };
275
276 let im_start_id = 151644u32; let assistant_id = 77091u32; let newline_id = 198u32; let role_prefix_ids = [im_start_id, assistant_id, newline_id];
284
285 info!(
286 "TTS: role_prefix={} content={} tokens",
287 role_prefix_ids.len(),
288 content_ids.len()
289 );
290
291 let role_embed = embed_text_ids(&role_prefix_ids)?;
293
294 let resolved_lang = if language == "auto" {
296 "chinese"
297 } else {
298 language
299 };
300 let language_id = self
301 .config
302 .codec_language_id
303 .get(&resolved_lang.to_lowercase());
304 let codec_prefix_ids = if let Some(&lang_id) = language_id {
305 vec![
306 self.config.codec_think_id,
307 self.config.codec_think_bos_id,
308 lang_id,
309 self.config.codec_think_eos_id,
310 ]
311 } else {
312 vec![
313 self.config.codec_nothink_id,
314 self.config.codec_think_bos_id,
315 self.config.codec_think_eos_id,
316 ]
317 };
318 let speaker_token = if resolved_lang == "chinese" {
321 3065u32
322 } else {
323 3061u32
324 };
325 let codec_full = {
326 let mut v = codec_prefix_ids.clone();
327 v.push(speaker_token);
328 v.push(self.config.codec_pad_id);
329 v.push(self.config.codec_bos_id);
330 v
331 };
332 let codec_embed = embed_codec_ids(&codec_full)?;
333
334 let n_codec = codec_full.len();
336 let mut tts_prefix_ids = vec![tts_pad; n_codec - 1];
337 tts_prefix_ids.push(tts_bos);
338 let tts_prefix_embed = embed_text_ids(&tts_prefix_ids)?;
339
340 let codec_first = codec_embed
342 .narrow(1, 0, n_codec - 1)
343 .map_err(|e| FerrumError::model(format!("codec narrow: {e}")))?;
344 let n_prefix = n_codec - 1; let codec_prefix_part = codec_embed
353 .narrow(1, 0, n_prefix)
354 .map_err(|e| FerrumError::model(format!("codec narrow: {e}")))?;
355
356 let mut tts_text_prefix_ids = vec![tts_pad; n_prefix - 1];
358 tts_text_prefix_ids.push(tts_bos);
359 let tts_text_embed = embed_text_ids(&tts_text_prefix_ids)?;
360
361 let codec_hidden = (&tts_text_embed + &codec_prefix_part)
362 .map_err(|e| FerrumError::model(format!("prefix sum: {e}")))?;
363
364 let codec_bos_embed = codec_embed
366 .narrow(1, n_prefix, 1)
367 .map_err(|e| FerrumError::model(format!("codec bos: {e}")))?;
368
369 let first_text_combined = if !content_ids.is_empty() {
371 let first_text_embed = embed_text_ids(&content_ids[..1])?;
372 (&first_text_embed + &codec_bos_embed)
373 .map_err(|e| FerrumError::model(format!("first text+bos: {e}")))?
374 } else {
375 codec_bos_embed.clone()
376 };
377
378 let prefill_embeds = Tensor::cat(&[&role_embed, &codec_hidden, &first_text_combined], 1)
380 .map_err(|e| FerrumError::model(format!("prefill cat: {e}")))?;
381
382 let plen = prefill_embeds.dim(1).unwrap_or(0);
383 info!("TTS: prefill_len = {}", plen);
384 if let Ok(v) = prefill_embeds
386 .narrow(0, 0, 1)
387 .and_then(|t| t.narrow(1, 0, 1))
388 .and_then(|t| t.narrow(2, 0, 5))
389 .and_then(|t| t.flatten_all())
390 .and_then(|t| t.to_vec1::<f32>())
391 {
392 info!(" prefill_input pos0[:5] = {:?}", v);
393 }
394 if plen > 0 {
395 if let Ok(v) = prefill_embeds
396 .narrow(0, 0, 1)
397 .and_then(|t| t.narrow(1, plen - 1, 1))
398 .and_then(|t| t.narrow(2, 0, 5))
399 .and_then(|t| t.flatten_all())
400 .and_then(|t| t.to_vec1::<f32>())
401 {
402 info!(" prefill_input pos-1[:5] = {:?}", v);
403 }
404 }
405
406 let mut trailing_ids: Vec<u32> = if content_ids.len() > 1 {
408 content_ids[1..].to_vec()
409 } else {
410 Vec::new()
411 };
412 trailing_ids.push(tts_eos);
413 let trailing_text_embeds = embed_text_ids(&trailing_ids)?;
414 let trailing_text_len = trailing_text_embeds
415 .dim(1)
416 .map_err(|e| FerrumError::model(format!("trailing dim: {e}")))?;
417 let tts_pad_embed = embed_text_ids(&[tts_pad])?;
418
419 info!("TTS: trailing_text_len = {}", trailing_text_len);
420
421 let mut hidden = self.talker.forward_step(&prefill_embeds)?;
423
424 let current_logits = self.talker.logits(
426 &hidden
427 .narrow(1, hidden.dim(1).unwrap() - 1, 1)
428 .map_err(|e| FerrumError::model(format!("narrow: {e}")))?,
429 )?;
430
431 let mut all_codec_tokens: Vec<Vec<u32>> = Vec::new();
433 let mut current_logits = current_logits;
434
435 let suppress_start = self.config.vocab_size.saturating_sub(1024);
437 let suppress_end = self.config.vocab_size;
438 let mut generated_tokens: Vec<u32> = Vec::new();
439
440 for step in 0..MAX_CODEC_TOKENS {
441 let mut logits_vec = logits_to_vec(¤t_logits)?;
443 for i in suppress_start..suppress_end.min(logits_vec.len()) {
445 if i as u32 != codec_eos {
446 logits_vec[i] = f32::NEG_INFINITY;
447 }
448 }
449 for &prev_tok in &generated_tokens {
451 let idx = prev_tok as usize;
452 if idx < logits_vec.len() {
453 if logits_vec[idx] > 0.0 {
454 logits_vec[idx] /= REPETITION_PENALTY;
455 } else {
456 logits_vec[idx] *= REPETITION_PENALTY;
457 }
458 }
459 }
460 let next_token =
461 sample_token(&logits_vec, tts_temperature(), TOP_K, REPETITION_PENALTY);
462
463 if step < 10 {
464 let argmax_tok = logits_vec
466 .iter()
467 .enumerate()
468 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
469 .map(|(i, v)| (i, *v))
470 .unwrap_or((0, 0.0));
471 info!(
472 "TOKEN step={} sampled={} argmax=({}, {:.2})",
473 step, next_token, argmax_tok.0, argmax_tok.1
474 );
475 }
476
477 generated_tokens.push(next_token);
478
479 if next_token == codec_eos {
481 info!("TTS: codec EOS at step {}", step);
482 break;
483 }
484
485 let last_hidden = hidden
487 .narrow(1, hidden.dim(1).unwrap() - 1, 1)
488 .map_err(|e| FerrumError::model(format!("last_hidden: {e}")))?;
489
490 let token_tensor = Tensor::new(&[next_token], &device)
492 .map_err(|e| FerrumError::model(format!("token tensor: {e}")))?
493 .unsqueeze(0)
494 .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?;
495 let first_codec_embed = self.talker.embed_codec(&token_tensor)?;
496
497 let st_t0 = std::time::Instant::now();
499 let extra_codes = self.sub_talker.predict(
500 &last_hidden,
501 &first_codec_embed,
502 st_temperature(),
503 TOP_K,
504 )?;
505 if step == 0 {
506 info!(
507 " SubTalker: {:.1}ms",
508 st_t0.elapsed().as_secs_f64() * 1000.0
509 );
510 }
511
512 let mut frame_codes = vec![next_token];
513 frame_codes.extend_from_slice(&extra_codes);
514 all_codec_tokens.push(frame_codes);
515
516 let mut combined_embed = first_codec_embed.clone();
519 for (i, &code) in extra_codes.iter().enumerate() {
520 let code_t = Tensor::new(&[code], &device)
521 .and_then(|t| t.unsqueeze(0))
522 .map_err(|e| FerrumError::model(format!("code_t: {e}")))?;
523 let sub_embed = code_t
524 .apply(&self.sub_talker.codec_embeddings[i])
525 .map_err(|e| FerrumError::model(format!("sub_embed: {e}")))?;
526 combined_embed = (combined_embed + sub_embed)
527 .map_err(|e| FerrumError::model(format!("add embed: {e}")))?;
528 }
529
530 if step == 0 {
531 if let Ok(v) = combined_embed
532 .flatten_all()
533 .and_then(|t| t.narrow(0, 0, 5))
534 .and_then(|t| t.to_vec1::<f32>())
535 {
536 info!("STEP0 codec_sum[:5] = {:?} (before trailing)", v);
537 }
538 }
539 if step < trailing_text_len {
544 let trail = trailing_text_embeds
545 .narrow(1, step, 1)
546 .map_err(|e| FerrumError::model(format!("trailing narrow: {e}")))?;
547 combined_embed = (combined_embed + trail)
548 .map_err(|e| FerrumError::model(format!("add trailing: {e}")))?;
549 } else {
550 combined_embed = (combined_embed + &tts_pad_embed)
551 .map_err(|e| FerrumError::model(format!("add tts_pad: {e}")))?;
552 }
553
554 if step == 0 {
556 if let Ok(v) = first_codec_embed
557 .flatten_all()
558 .and_then(|t| t.narrow(0, 0, 5))
559 .and_then(|t| t.to_vec1::<f32>())
560 {
561 info!("STEP0 semantic[:5] = {:?}", v);
562 }
563 if let Ok(v) = combined_embed
564 .flatten_all()
565 .and_then(|t| t.narrow(0, 0, 5))
566 .and_then(|t| t.to_vec1::<f32>())
567 {
568 info!("STEP0 combined[:5] = {:?}", v);
569 }
570 }
571 let tk_t0 = std::time::Instant::now();
573 hidden = self.talker.forward_step(&combined_embed)?;
574 current_logits = self.talker.logits(&hidden)?;
575 if step == 0 {
576 info!(
577 " Talker step: {:.1}ms",
578 tk_t0.elapsed().as_secs_f64() * 1000.0
579 );
580 }
581 }
582
583 if all_codec_tokens.is_empty() {
584 return Err(FerrumError::model("no codec tokens generated"));
585 }
586
587 info!("TTS: generated {} codec frames", all_codec_tokens.len());
588
589 let num_frames = all_codec_tokens.len();
591 let num_groups = self.config.num_code_groups;
592 let mut flat_codes: Vec<u32> = vec![0; num_groups * num_frames];
593 for (t, frame) in all_codec_tokens.iter().enumerate() {
594 for (g, &code) in frame.iter().enumerate() {
595 flat_codes[g * num_frames + t] = code;
596 }
597 }
598
599 let codebook_size = 2048u32;
602 for code in &mut flat_codes {
603 if *code >= codebook_size {
604 *code = 0; }
606 }
607
608 let codes_tensor = Tensor::new(&flat_codes[..], &device)
609 .map_err(|e| FerrumError::model(format!("codes tensor: {e}")))?
610 .reshape((1, num_groups, num_frames))
611 .map_err(|e| FerrumError::model(format!("reshape codes: {e}")))?;
612
613 let waveform = self.vocoder.decode(&codes_tensor)?;
615
616 let samples: Vec<f32> = waveform
618 .squeeze(0)
619 .map_err(|e| FerrumError::model(format!("squeeze batch: {e}")))?
620 .squeeze(0)
621 .map_err(|e| FerrumError::model(format!("squeeze channel: {e}")))?
622 .to_vec1()
623 .map_err(|e| FerrumError::model(format!("to_vec1: {e}")))?;
624
625 info!(
626 "TTS: waveform {} samples ({:.2}s @ {}Hz)",
627 samples.len(),
628 samples.len() as f64 / SAMPLE_RATE as f64,
629 SAMPLE_RATE,
630 );
631
632 Ok(samples)
633 }
634
635 fn decode_frames(&mut self, frames: &[Vec<u32>], device: &CandleDevice) -> Result<Vec<f32>> {
638 let num_frames = frames.len();
639 if num_frames == 0 {
640 return Ok(vec![]);
641 }
642 let num_groups = self.config.num_code_groups;
643 let codebook_size = 2048u32;
644
645 let mut flat_codes: Vec<u32> = vec![0; num_groups * num_frames];
646 for (t, frame) in frames.iter().enumerate() {
647 for (g, &code) in frame.iter().take(num_groups).enumerate() {
648 flat_codes[g * num_frames + t] = if code >= codebook_size { 0 } else { code };
649 }
650 }
651
652 let codes_tensor = Tensor::new(&flat_codes[..], device)
653 .map_err(|e| FerrumError::model(format!("codes tensor: {e}")))?
654 .reshape((1, num_groups, num_frames))
655 .map_err(|e| FerrumError::model(format!("reshape codes: {e}")))?;
656
657 let waveform = self.vocoder.decode(&codes_tensor)?;
658 waveform
659 .squeeze(0)
660 .and_then(|t| t.squeeze(0))
661 .and_then(|t| t.to_vec1())
662 .map_err(|e| FerrumError::model(format!("waveform extract: {e}")))
663 }
664
665 pub fn synthesize_streaming<F: FnMut(usize, &[f32])>(
670 &mut self,
671 text: &str,
672 language: &str,
673 chunk_frames: usize,
674 mut on_chunk: F,
675 ) -> Result<Vec<Vec<f32>>> {
676 self.talker.reset();
679 let device = self.talker.device().clone();
680
681 let encoding = self
682 .text_tokenizer
683 .encode(text, false)
684 .map_err(|e| FerrumError::model(format!("tokenize: {e}")))?;
685 let content_ids: Vec<u32> = encoding.get_ids().to_vec();
686 if content_ids.is_empty() {
687 return Err(FerrumError::model("empty text after tokenization"));
688 }
689
690 let codec_eos = self.config.codec_eos_token_id;
691 let tts_pad = self.config.tts_pad_token_id;
692 let tts_bos = self.config.tts_bos_token_id;
693 let tts_eos = self.config.tts_eos_token_id;
694
695 let embed_text_ids = |ids: &[u32]| -> Result<Tensor> {
697 let t = Tensor::new(ids, &device)
698 .map_err(|e| FerrumError::model(format!("text tensor: {e}")))?
699 .unsqueeze(0)
700 .map_err(|e| FerrumError::model(format!("text unsqueeze: {e}")))?;
701 self.talker.embed_text(&t)
702 };
703 let embed_codec_ids = |ids: &[u32]| -> Result<Tensor> {
704 let t = Tensor::new(ids, &device)
705 .map_err(|e| FerrumError::model(format!("codec tensor: {e}")))?
706 .unsqueeze(0)
707 .map_err(|e| FerrumError::model(format!("codec unsqueeze: {e}")))?;
708 self.talker.embed_codec(&t)
709 };
710
711 let resolved_lang = if language.eq_ignore_ascii_case("auto") {
713 "chinese"
714 } else {
715 language
716 };
717 let language_id = self
718 .config
719 .codec_language_id
720 .get(&resolved_lang.to_lowercase());
721 let codec_prefix_ids = if let Some(&lang_id) = language_id {
722 vec![
723 self.config.codec_think_id,
724 self.config.codec_think_bos_id,
725 lang_id,
726 self.config.codec_think_eos_id,
727 ]
728 } else {
729 vec![
730 self.config.codec_nothink_id,
731 self.config.codec_think_bos_id,
732 self.config.codec_think_eos_id,
733 ]
734 };
735 let speaker_token = if resolved_lang == "chinese" {
736 3065u32
737 } else {
738 3061u32
739 };
740 let mut codec_ids = codec_prefix_ids;
741 codec_ids.push(speaker_token);
742 codec_ids.push(self.config.codec_pad_id);
743 codec_ids.push(self.config.codec_bos_id);
744 let codec_embed = embed_codec_ids(&codec_ids)?;
745 let n_codec = codec_embed
746 .dim(1)
747 .map_err(|e| FerrumError::model(format!("dim: {e}")))?;
748 let n_prefix = n_codec - 1;
749 let codec_prefix_part = codec_embed
750 .narrow(1, 0, n_prefix)
751 .map_err(|e| FerrumError::model(format!("narrow: {e}")))?;
752
753 let mut tts_text_prefix_ids = vec![tts_pad; n_prefix - 1];
754 tts_text_prefix_ids.push(tts_bos);
755 let tts_text_embed = embed_text_ids(&tts_text_prefix_ids)?;
756 let codec_hidden = (&tts_text_embed + &codec_prefix_part)
757 .map_err(|e| FerrumError::model(format!("sum: {e}")))?;
758 let codec_bos_embed = codec_embed
759 .narrow(1, n_prefix, 1)
760 .map_err(|e| FerrumError::model(format!("bos: {e}")))?;
761
762 let role_ids: &[u32] = &[151644, 77091, 198]; let role_embed = embed_text_ids(role_ids)?;
764
765 let first_text_combined = if !content_ids.is_empty() {
766 let first_text_embed = embed_text_ids(&content_ids[..1])?;
767 (&first_text_embed + &codec_bos_embed)
768 .map_err(|e| FerrumError::model(format!("first: {e}")))?
769 } else {
770 codec_bos_embed.clone()
771 };
772
773 let prefill_embeds = Tensor::cat(&[&role_embed, &codec_hidden, &first_text_combined], 1)
774 .map_err(|e| FerrumError::model(format!("prefill cat: {e}")))?;
775
776 let trailing_text_embeds = if content_ids.len() > 1 {
778 let remaining = embed_text_ids(&content_ids[1..])?;
779 let eos = embed_text_ids(&[tts_eos])?;
780 Tensor::cat(&[&remaining, &eos], 1)
781 .map_err(|e| FerrumError::model(format!("trailing: {e}")))?
782 } else {
783 embed_text_ids(&[tts_eos])?
784 };
785 let trailing_text_len = trailing_text_embeds
786 .dim(1)
787 .map_err(|e| FerrumError::model(format!("dim: {e}")))?;
788 let tts_pad_embed = embed_text_ids(&[tts_pad])?;
789
790 let mut hidden = self.talker.forward_step(&prefill_embeds)?;
792 let mut current_logits = self.talker.logits(
793 &hidden
794 .narrow(1, hidden.dim(1).unwrap() - 1, 1)
795 .map_err(|e| FerrumError::model(format!("narrow: {e}")))?,
796 )?;
797
798 let suppress_start = self.config.vocab_size.saturating_sub(1024);
800 let suppress_end = self.config.vocab_size;
801 let mut generated_tokens: Vec<u32> = Vec::new();
802 let mut frame_buffer: Vec<Vec<u32>> = Vec::new();
803 let mut audio_chunks: Vec<Vec<f32>> = Vec::new();
804
805 for step in 0..MAX_CODEC_TOKENS {
806 let mut logits_vec = logits_to_vec(¤t_logits)?;
807 for i in suppress_start..suppress_end.min(logits_vec.len()) {
808 if i as u32 != codec_eos {
809 logits_vec[i] = f32::NEG_INFINITY;
810 }
811 }
812 for &prev_tok in &generated_tokens {
813 let idx = prev_tok as usize;
814 if idx < logits_vec.len() {
815 if logits_vec[idx] > 0.0 {
816 logits_vec[idx] /= REPETITION_PENALTY;
817 } else {
818 logits_vec[idx] *= REPETITION_PENALTY;
819 }
820 }
821 }
822 let next_token =
823 sample_token(&logits_vec, tts_temperature(), TOP_K, REPETITION_PENALTY);
824 generated_tokens.push(next_token);
825
826 if next_token == codec_eos {
827 info!("TTS streaming: EOS at step {}", step);
828 break;
829 }
830
831 let last_hidden = hidden
832 .narrow(1, hidden.dim(1).unwrap() - 1, 1)
833 .map_err(|e| FerrumError::model(format!("narrow: {e}")))?;
834 let token_tensor = Tensor::new(&[next_token], &device)
835 .and_then(|t| t.unsqueeze(0))
836 .map_err(|e| FerrumError::model(format!("tok: {e}")))?;
837 let first_codec_embed = self.talker.embed_codec(&token_tensor)?;
838
839 let extra_codes = self.sub_talker.predict(
840 &last_hidden,
841 &first_codec_embed,
842 st_temperature(),
843 TOP_K,
844 )?;
845
846 let mut frame = vec![next_token];
847 frame.extend_from_slice(&extra_codes);
848 frame_buffer.push(frame);
849
850 if frame_buffer.len() >= chunk_frames {
852 let chunk_audio = self.decode_frames(&frame_buffer, &device)?;
853 on_chunk(audio_chunks.len(), &chunk_audio);
854 audio_chunks.push(chunk_audio);
855 frame_buffer.clear();
856 }
857
858 let mut combined_embed = first_codec_embed.clone();
860 for (i, &code) in extra_codes.iter().enumerate() {
861 let code_t = Tensor::new(&[code], &device)
862 .and_then(|t| t.unsqueeze(0))
863 .map_err(|e| FerrumError::model(format!("code_t: {e}")))?;
864 let sub_embed = code_t
865 .apply(&self.sub_talker.codec_embeddings[i])
866 .map_err(|e| FerrumError::model(format!("sub: {e}")))?;
867 combined_embed = (combined_embed + sub_embed)
868 .map_err(|e| FerrumError::model(format!("add: {e}")))?;
869 }
870 if step < trailing_text_len {
871 let trail = trailing_text_embeds
872 .narrow(1, step, 1)
873 .map_err(|e| FerrumError::model(format!("trail: {e}")))?;
874 combined_embed = (combined_embed + trail)
875 .map_err(|e| FerrumError::model(format!("add trail: {e}")))?;
876 } else {
877 combined_embed = (combined_embed + &tts_pad_embed)
878 .map_err(|e| FerrumError::model(format!("pad: {e}")))?;
879 }
880
881 hidden = self.talker.forward_step(&combined_embed)?;
882 current_logits = self.talker.logits(&hidden)?;
883 }
884
885 if !frame_buffer.is_empty() {
887 let chunk_audio = self.decode_frames(&frame_buffer, &device)?;
888 on_chunk(audio_chunks.len(), &chunk_audio);
889 audio_chunks.push(chunk_audio);
890 }
891
892 Ok(audio_chunks)
893 }
894
895 pub fn sample_rate(&self) -> usize {
897 SAMPLE_RATE
898 }
899
900 pub fn config(&self) -> &TalkerConfig {
901 &self.config
902 }
903
904 pub fn synthesize_voice_clone(
912 &mut self,
913 text: &str,
914 language: &str,
915 ref_audio_path: &str,
916 ref_text: &str,
917 ) -> Result<Vec<f32>> {
918 let device = self.talker.device().clone();
919
920 let ref_pcm = if let Ok(path) = std::env::var("FERRUM_REF_PCM") {
922 let data = std::fs::read(&path)
923 .map_err(|e| FerrumError::model(format!("read ref pcm: {e}")))?;
924 let pcm: Vec<f32> = data
925 .chunks(4)
926 .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
927 .collect();
928 info!("Loaded ref PCM override: {} samples", pcm.len());
929 pcm
930 } else {
931 crate::audio_processor::load_audio_at_rate(ref_audio_path, 24000)?
932 };
933 info!(
934 "TTS voice clone: loaded ref audio {} samples ({:.2}s)",
935 ref_pcm.len(),
936 ref_pcm.len() as f64 / 24000.0
937 );
938
939 let t0 = std::time::Instant::now();
940 let speaker_encoder = self
942 .speaker_encoder
943 .as_ref()
944 .ok_or_else(|| FerrumError::model("speaker encoder not loaded"))?;
945 let mel = mel_spectrogram_speaker_encoder(&ref_pcm);
946 let n_mel_frames = mel.len() / 128;
947 let mel_tensor = Tensor::from_vec(mel, (1, n_mel_frames, 128), &device)
950 .map_err(|e| FerrumError::model(format!("mel tensor: {e}")))?;
951 let spk_embed = speaker_encoder.forward(&mel_tensor)?;
952 info!(
953 "Step 2 (speaker embed): {:.1}ms",
954 t0.elapsed().as_secs_f64() * 1000.0
955 );
956 let spk_embed = spk_embed
958 .unsqueeze(0)
959 .map_err(|e| FerrumError::model(format!("spk unsqueeze(0): {e}")))?
960 .unsqueeze(0)
961 .map_err(|e| FerrumError::model(format!("spk unsqueeze(0) 2: {e}")))?;
962
963 let t1 = std::time::Instant::now();
964 let speech_enc = self
966 .speech_tokenizer_encoder
967 .as_ref()
968 .ok_or_else(|| FerrumError::model("speech tokenizer encoder not loaded"))?;
969 let ref_codes = if let Ok(path) = std::env::var("FERRUM_REF_CODES") {
971 let data = std::fs::read(&path)
972 .map_err(|e| FerrumError::model(format!("read ref codes: {e}")))?;
973 let u32s: Vec<u32> = data
974 .chunks(4)
975 .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
976 .collect();
977 let ncb = self.config.num_code_groups;
978 let nframes = u32s.len() / ncb;
979 info!(
980 "Loaded pre-computed ref codes: {} frames from {}",
981 nframes, path
982 );
983 u32s.chunks(ncb).map(|c| c.to_vec()).collect()
984 } else {
985 let codes = speech_enc.encode(&ref_pcm)?;
986 info!(
987 "Step 3 (speech tokenizer): {:.1}ms",
988 t1.elapsed().as_secs_f64() * 1000.0
989 );
990 codes
991 };
992 let ref_frames = ref_codes.len();
993 info!(
994 "TTS voice clone: ref_frames={}, spk_embed loaded",
995 ref_frames
996 );
997 for i in 0..ref_frames.min(5) {
999 info!(" rust codec frame {}: {:?}", i, &ref_codes[i]);
1000 }
1001
1002 info!(
1003 "Step 3 (speech tokenizer): {:.1}ms",
1004 t1.elapsed().as_secs_f64() * 1000.0
1005 );
1006 let t2 = std::time::Instant::now();
1007 let chat_text = format!("<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n");
1009 let encoding = self
1010 .text_tokenizer
1011 .encode(chat_text.as_str(), false)
1012 .map_err(|e| FerrumError::model(format!("tokenize: {e}")))?;
1013 let input_ids: Vec<u32> = encoding.get_ids().to_vec();
1014 let role_ids = &input_ids[..3];
1016 let text_content_ids = &input_ids[3..input_ids.len().saturating_sub(5)];
1017
1018 let ref_chat_text = format!("<|im_start|>assistant\n{ref_text}<|im_end|>\n");
1020 let ref_encoding = self
1021 .text_tokenizer
1022 .encode(ref_chat_text.as_str(), false)
1023 .map_err(|e| FerrumError::model(format!("tokenize ref: {e}")))?;
1024 let ref_ids: Vec<u32> = ref_encoding.get_ids().to_vec();
1025 let ref_text_ids = &ref_ids[3..ref_ids.len().saturating_sub(2)];
1027
1028 self.talker.reset();
1030
1031 let tts_bos = self.config.tts_bos_token_id;
1032 let tts_eos = self.config.tts_eos_token_id;
1033 let tts_pad = self.config.tts_pad_token_id;
1034 let codec_bos = self.config.codec_bos_id;
1035 let codec_eos = self.config.codec_eos_token_id;
1036 let codec_pad = self.config.codec_pad_id;
1037
1038 let embed_codec_ids = |ids: &[u32]| -> Result<Tensor> {
1040 let t = Tensor::new(ids, &device)
1041 .map_err(|e| FerrumError::model(format!("codec tensor: {e}")))?
1042 .unsqueeze(0)
1043 .map_err(|e| FerrumError::model(format!("codec unsqueeze: {e}")))?;
1044 self.talker.embed_codec(&t)
1045 };
1046 let embed_text_ids = |ids: &[u32]| -> Result<Tensor> {
1047 let t = Tensor::new(ids, &device)
1048 .map_err(|e| FerrumError::model(format!("text tensor: {e}")))?
1049 .unsqueeze(0)
1050 .map_err(|e| FerrumError::model(format!("text unsqueeze: {e}")))?;
1051 self.talker.embed_text(&t)
1052 };
1053
1054 let tts_special = embed_text_ids(&[tts_bos, tts_eos, tts_pad])?;
1056 let tts_bos_embed = tts_special
1057 .narrow(1, 0, 1)
1058 .map_err(|e| FerrumError::model(format!("tts_bos narrow: {e}")))?;
1059 let tts_eos_embed = tts_special
1060 .narrow(1, 1, 1)
1061 .map_err(|e| FerrumError::model(format!("tts_eos narrow: {e}")))?;
1062 let tts_pad_embed = tts_special
1063 .narrow(1, 2, 1)
1064 .map_err(|e| FerrumError::model(format!("tts_pad narrow: {e}")))?;
1065
1066 let resolved_lang = if language.eq_ignore_ascii_case("auto") {
1068 "chinese"
1069 } else {
1070 language
1071 };
1072 let language_id = self
1073 .config
1074 .codec_language_id
1075 .get(&resolved_lang.to_lowercase());
1076
1077 let codec_prefix_ids = if let Some(&lang_id) = language_id {
1079 vec![
1080 self.config.codec_think_id,
1081 self.config.codec_think_bos_id,
1082 lang_id,
1083 self.config.codec_think_eos_id,
1084 ]
1085 } else {
1086 vec![
1087 self.config.codec_nothink_id,
1088 self.config.codec_think_bos_id,
1089 self.config.codec_think_eos_id,
1090 ]
1091 };
1092 let codec_prefix_embed = embed_codec_ids(&codec_prefix_ids)?;
1093
1094 let codec_suffix_embed = embed_codec_ids(&[codec_pad, codec_bos])?;
1096
1097 let codec_input = Tensor::cat(&[&codec_prefix_embed, &spk_embed, &codec_suffix_embed], 1)
1099 .map_err(|e| FerrumError::model(format!("codec_input cat: {e}")))?;
1100 let codec_len = codec_input
1101 .dim(1)
1102 .map_err(|e| FerrumError::model(format!("codec_len dim: {e}")))?;
1103
1104 let role_embed = embed_text_ids(role_ids)?;
1106
1107 let n_pads = codec_len - 2;
1109 let mut text_prefix_parts = Vec::new();
1110 for _ in 0..n_pads {
1111 text_prefix_parts.push(tts_pad_embed.clone());
1112 }
1113 text_prefix_parts.push(tts_bos_embed.clone());
1114 let text_prefix_refs: Vec<&Tensor> = text_prefix_parts.iter().collect();
1115 let text_prefix = Tensor::cat(&text_prefix_refs, 1)
1116 .map_err(|e| FerrumError::model(format!("text_prefix cat: {e}")))?;
1117 let codec_prefix_part = codec_input
1118 .narrow(1, 0, codec_len - 1)
1119 .map_err(|e| FerrumError::model(format!("codec prefix narrow: {e}")))?;
1120 let text_codec_prefix = (&text_prefix + &codec_prefix_part)
1121 .map_err(|e| FerrumError::model(format!("text+codec prefix sum: {e}")))?;
1122
1123 let prefill_embed = Tensor::cat(&[&role_embed, &text_codec_prefix], 1)
1125 .map_err(|e| FerrumError::model(format!("prefill cat: {e}")))?;
1126
1127 let t3 = std::time::Instant::now();
1128
1129 let all_text_ids: Vec<u32> = ref_text_ids
1132 .iter()
1133 .chain(text_content_ids.iter())
1134 .copied()
1135 .collect();
1136 let text_embed = embed_text_ids(&all_text_ids)?;
1137 let text_embed_with_eos = Tensor::cat(&[&text_embed, &tts_eos_embed], 1)
1138 .map_err(|e| FerrumError::model(format!("text+eos cat: {e}")))?;
1139 let text_len = text_embed_with_eos
1140 .dim(1)
1141 .map_err(|e| FerrumError::model(format!("text_len dim: {e}")))?;
1142
1143 let t_codec_start = std::time::Instant::now();
1146 let ncg = self.config.num_code_groups;
1147 let first_codes: Vec<u32> = ref_codes.iter().map(|f| f[0]).collect();
1148 let codec_frames_cat = {
1149 let mut sum = embed_codec_ids(&first_codes)?;
1151 for cb in 0..(ncg - 1) {
1153 let codes: Vec<u32> = ref_codes.iter().map(|f| f[cb + 1]).collect();
1154 let codes_t = Tensor::new(codes.as_slice(), &device)
1155 .map_err(|e| FerrumError::model(format!("batch codes: {e}")))?
1156 .unsqueeze(0)
1157 .map_err(|e| FerrumError::model(format!("batch unsqueeze: {e}")))?;
1158 let sub_embed = codes_t
1159 .apply(&self.sub_talker.codec_embeddings[cb])
1160 .map_err(|e| FerrumError::model(format!("batch sub_embed: {e}")))?;
1161 sum =
1162 (sum + sub_embed).map_err(|e| FerrumError::model(format!("batch add: {e}")))?;
1163 }
1164 sum
1165 };
1166 info!(
1167 "Codec embedding: {:.1}ms ({} frames × {} codebooks)",
1168 t_codec_start.elapsed().as_secs_f64() * 1000.0,
1169 ref_codes.len(),
1170 ncg
1171 );
1172
1173 let t_merge = std::time::Instant::now();
1175 let codec_bos_for_icl = embed_codec_ids(&[codec_bos])?;
1176 let icl_codec = Tensor::cat(&[&codec_bos_for_icl, &codec_frames_cat], 1)
1177 .map_err(|e| FerrumError::model(format!("icl_codec cat: {e}")))?;
1178 let codec_icl_len = icl_codec
1179 .dim(1)
1180 .map_err(|e| FerrumError::model(format!("codec_icl_len dim: {e}")))?;
1181
1182 let icl_trailing: Tensor;
1184 let icl_embed: Tensor;
1185 if text_len > codec_icl_len {
1186 let text_part = text_embed_with_eos
1187 .narrow(1, 0, codec_icl_len)
1188 .map_err(|e| FerrumError::model(format!("text_part narrow: {e}")))?;
1189 icl_embed = (&text_part + &icl_codec)
1190 .map_err(|e| FerrumError::model(format!("text+codec sum: {e}")))?;
1191 icl_trailing = text_embed_with_eos
1192 .narrow(1, codec_icl_len, text_len - codec_icl_len)
1193 .map_err(|e| FerrumError::model(format!("trailing narrow: {e}")))?;
1194 } else {
1195 let n_pad = codec_icl_len - text_len;
1196 let text_padded = if n_pad > 0 {
1197 let pad_block = tts_pad_embed
1198 .expand((1, n_pad, self.config.hidden_size))
1199 .map_err(|e| FerrumError::model(format!("pad expand: {e}")))?;
1200 Tensor::cat(&[&text_embed_with_eos, &pad_block], 1)
1201 .map_err(|e| FerrumError::model(format!("text_padded cat: {e}")))?
1202 } else {
1203 text_embed_with_eos.clone()
1204 };
1205 icl_embed = (&text_padded + &icl_codec)
1206 .map_err(|e| FerrumError::model(format!("padded+codec sum: {e}")))?;
1207 icl_trailing = tts_pad_embed.clone();
1208 }
1209 let trailing_text_len = icl_trailing
1210 .dim(1)
1211 .map_err(|e| FerrumError::model(format!("trailing dim: {e}")))?;
1212
1213 let _prefill_out = self.talker.forward_step(&prefill_embed)?;
1216 let t_icl = std::time::Instant::now();
1217 let icl_hidden = self.talker.forward_step(&icl_embed)?;
1218 let icl_len = icl_hidden
1219 .dim(1)
1220 .map_err(|e| FerrumError::model(format!("icl_hidden dim: {e}")))?;
1221 info!(
1222 "ICL block: {:.1}ms ({} tokens), trailing={}",
1223 t_icl.elapsed().as_secs_f64() * 1000.0,
1224 icl_len,
1225 trailing_text_len
1226 );
1227
1228 let mut hidden = icl_hidden;
1230 let hidden_len = hidden
1231 .dim(1)
1232 .map_err(|e| FerrumError::model(format!("hidden dim: {e}")))?;
1233 let last_hidden = hidden
1234 .narrow(1, hidden_len - 1, 1)
1235 .map_err(|e| FerrumError::model(format!("narrow last: {e}")))?;
1236 if let Ok(v) = last_hidden.flatten_all().and_then(|t| t.to_vec1::<f32>()) {}
1237 let current_logits = self.talker.logits(&last_hidden)?;
1238 {}
1239
1240 let mut all_codec_tokens: Vec<Vec<u32>> = Vec::new();
1242 let mut current_logits = current_logits;
1243
1244 let suppress_start = self.config.vocab_size.saturating_sub(1024);
1246 let suppress_end = self.config.vocab_size;
1247
1248 const ICL_REPETITION_PENALTY: f32 = 1.5;
1251 const ICL_FRAMES_PER_TOKEN: usize = 6;
1252 const ICL_MIN_FRAMES: usize = 75;
1253 let max_icl_tokens = ICL_MIN_FRAMES.max(text_content_ids.len() * ICL_FRAMES_PER_TOKEN);
1254 let mut generated_tokens: Vec<u32> = Vec::new();
1255
1256 for step in 0..max_icl_tokens {
1257 let mut logits_vec = logits_to_vec(¤t_logits)?;
1258 for i in suppress_start..suppress_end.min(logits_vec.len()) {
1260 if i as u32 != codec_eos {
1261 logits_vec[i] = f32::NEG_INFINITY;
1262 }
1263 }
1264 let min_frames: usize = std::env::var("FERRUM_TTS_MIN_FRAMES")
1268 .ok()
1269 .and_then(|v| v.parse().ok())
1270 .unwrap_or_else(|| text_content_ids.len() * ICL_FRAMES_PER_TOKEN);
1271 if step < min_frames {
1272 if let Some(v) = logits_vec.get_mut(codec_eos as usize) {
1273 *v = f32::NEG_INFINITY;
1274 }
1275 }
1276 for &prev_tok in &generated_tokens {
1278 let idx = prev_tok as usize;
1279 if idx < logits_vec.len() {
1280 if logits_vec[idx] > 0.0 {
1281 logits_vec[idx] /= ICL_REPETITION_PENALTY;
1282 } else {
1283 logits_vec[idx] *= ICL_REPETITION_PENALTY;
1284 }
1285 }
1286 }
1287 let next_token = sample_token(
1288 &logits_vec,
1289 tts_temperature(),
1290 TOP_K,
1291 ICL_REPETITION_PENALTY,
1292 );
1293
1294 generated_tokens.push(next_token);
1295
1296 if next_token == codec_eos {
1297 info!("TTS voice clone: codec EOS at step {}", step);
1298 break;
1299 }
1300
1301 if generated_tokens.len() >= 6 {
1303 let n = generated_tokens.len();
1304 let mut is_repeat = false;
1305 for pat_len in 1..=4 {
1306 if n >= pat_len * 3 {
1307 let a = &generated_tokens[n - pat_len * 3..n - pat_len * 2];
1308 let b = &generated_tokens[n - pat_len * 2..n - pat_len];
1309 let c = &generated_tokens[n - pat_len..n];
1310 if a == b && b == c {
1311 is_repeat = true;
1312 break;
1313 }
1314 }
1315 }
1316 if is_repeat {
1317 info!(
1318 "TTS voice clone: repetition detected at step {}, stopping",
1319 step
1320 );
1321 break;
1322 }
1323 }
1324
1325 let cur_hidden_len = hidden
1326 .dim(1)
1327 .map_err(|e| FerrumError::model(format!("hidden dim: {e}")))?;
1328 let last_hidden = hidden
1329 .narrow(1, cur_hidden_len - 1, 1)
1330 .map_err(|e| FerrumError::model(format!("last_hidden: {e}")))?;
1331
1332 let token_tensor = Tensor::new(&[next_token], &device)
1333 .map_err(|e| FerrumError::model(format!("token tensor: {e}")))?
1334 .unsqueeze(0)
1335 .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?;
1336 let first_codec_embed = self.talker.embed_codec(&token_tensor)?;
1337
1338 let extra_codes = self.sub_talker.predict(
1339 &last_hidden,
1340 &first_codec_embed,
1341 st_temperature(),
1342 TOP_K,
1343 )?;
1344
1345 let mut frame_codes = vec![next_token];
1346 frame_codes.extend_from_slice(&extra_codes);
1347 all_codec_tokens.push(frame_codes);
1348
1349 let mut combined_embed = first_codec_embed.clone();
1351 for (i, &code) in extra_codes.iter().enumerate() {
1352 let code_t = Tensor::new(&[code], &device)
1353 .and_then(|t| t.unsqueeze(0))
1354 .map_err(|e| FerrumError::model(format!("code_t: {e}")))?;
1355 let sub_embed = code_t
1356 .apply(&self.sub_talker.codec_embeddings[i])
1357 .map_err(|e| FerrumError::model(format!("sub_embed: {e}")))?;
1358 combined_embed = (combined_embed + sub_embed)
1359 .map_err(|e| FerrumError::model(format!("add embed: {e}")))?;
1360 }
1361
1362 if step < trailing_text_len {
1364 let trail = icl_trailing
1365 .narrow(1, step, 1)
1366 .map_err(|e| FerrumError::model(format!("trailing narrow: {e}")))?;
1367 combined_embed = (combined_embed + trail)
1368 .map_err(|e| FerrumError::model(format!("add trailing: {e}")))?;
1369 } else {
1370 combined_embed = (combined_embed + &tts_pad_embed)
1371 .map_err(|e| FerrumError::model(format!("add tts_pad: {e}")))?;
1372 }
1373
1374 hidden = self.talker.forward_step(&combined_embed)?;
1375 current_logits = self.talker.logits(&hidden)?;
1376 }
1377
1378 if all_codec_tokens.is_empty() {
1379 return Err(FerrumError::model("no codec tokens generated"));
1380 }
1381 info!(
1382 "TTS voice clone: generated {} codec frames",
1383 all_codec_tokens.len()
1384 );
1385
1386 let mut all_codes_with_ref = ref_codes.clone();
1388 all_codes_with_ref.extend_from_slice(&all_codec_tokens);
1389
1390 let num_frames = all_codes_with_ref.len();
1391 let num_groups = self.config.num_code_groups;
1392
1393 let mut flat_codes: Vec<u32> = vec![0; num_groups * num_frames];
1395 for (t, frame) in all_codes_with_ref.iter().enumerate() {
1396 for (g, &code) in frame.iter().take(num_groups).enumerate() {
1397 flat_codes[g * num_frames + t] = code;
1398 }
1399 }
1400
1401 let codebook_size = 2048u32;
1403 for code in &mut flat_codes {
1404 if *code >= codebook_size {
1405 *code = 0;
1406 }
1407 }
1408
1409 let codes_tensor = Tensor::new(&flat_codes[..], &device)
1410 .map_err(|e| FerrumError::model(format!("codes tensor: {e}")))?
1411 .reshape((1, num_groups, num_frames))
1412 .map_err(|e| FerrumError::model(format!("reshape codes: {e}")))?;
1413
1414 let waveform = self.vocoder.decode(&codes_tensor)?;
1415
1416 let samples: Vec<f32> = waveform
1417 .squeeze(0)
1418 .map_err(|e| FerrumError::model(format!("squeeze batch: {e}")))?
1419 .squeeze(0)
1420 .map_err(|e| FerrumError::model(format!("squeeze channel: {e}")))?
1421 .to_vec1()
1422 .map_err(|e| FerrumError::model(format!("to_vec1: {e}")))?;
1423
1424 let ref_ratio = ref_frames as f64 / num_frames as f64;
1426 let cut = (ref_ratio * samples.len() as f64) as usize;
1427 let output_samples = samples[cut..].to_vec();
1428
1429 info!(
1430 "TTS voice clone: waveform {} samples ({:.2}s), trimmed ref {} samples",
1431 output_samples.len(),
1432 output_samples.len() as f64 / SAMPLE_RATE as f64,
1433 cut,
1434 );
1435
1436 Ok(output_samples)
1437 }
1438}
1439
1440fn find_safetensor_files(dir: &std::path::Path, prefix: &str) -> Result<Vec<std::path::PathBuf>> {
1444 let single = dir.join(format!("{prefix}.safetensors"));
1446 if single.exists() {
1447 return Ok(vec![single]);
1448 }
1449
1450 let mut files: Vec<std::path::PathBuf> = Vec::new();
1452 if let Ok(entries) = std::fs::read_dir(dir) {
1453 for entry in entries.flatten() {
1454 let path = entry.path();
1455 if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
1456 if name.starts_with(prefix)
1457 && name.ends_with(".safetensors")
1458 && name != format!("{prefix}.safetensors")
1459 {
1460 files.push(path);
1461 }
1462 }
1463 }
1464 }
1465 files.sort();
1466
1467 if files.is_empty() {
1468 Err(FerrumError::model(format!(
1469 "no safetensors files with prefix '{prefix}' in {}",
1470 dir.display()
1471 )))
1472 } else {
1473 Ok(files)
1474 }
1475}
1476
1477fn load_bpe_tokenizer(dir: &std::path::Path) -> Result<tokenizers::Tokenizer> {
1479 let tokenizer_json = dir.join("tokenizer.json");
1481 if tokenizer_json.exists() {
1482 return tokenizers::Tokenizer::from_file(&tokenizer_json)
1483 .map_err(|e| FerrumError::model(format!("load tokenizer.json: {e}")));
1484 }
1485
1486 let vocab_path = dir.join("vocab.json");
1488 let merges_path = dir.join("merges.txt");
1489
1490 if !vocab_path.exists() || !merges_path.exists() {
1491 return Err(FerrumError::model(
1492 "tokenizer.json not found, and vocab.json + merges.txt not found either",
1493 ));
1494 }
1495
1496 let vocab_data = std::fs::read_to_string(&vocab_path)
1497 .map_err(|e| FerrumError::model(format!("read vocab.json: {e}")))?;
1498 let vocab: HashMap<String, u32> = serde_json::from_str(&vocab_data)
1499 .map_err(|e| FerrumError::model(format!("parse vocab.json: {e}")))?;
1500
1501 let merges_data = std::fs::read_to_string(&merges_path)
1502 .map_err(|e| FerrumError::model(format!("read merges.txt: {e}")))?;
1503 let merges: Vec<(String, String)> = merges_data
1504 .lines()
1505 .skip(1) .filter(|line| !line.is_empty())
1507 .filter_map(|line| {
1508 let parts: Vec<&str> = line.splitn(2, ' ').collect();
1509 if parts.len() == 2 {
1510 Some((parts[0].to_string(), parts[1].to_string()))
1511 } else {
1512 None
1513 }
1514 })
1515 .collect();
1516
1517 let bpe = tokenizers::models::bpe::BPE::from_file(
1518 vocab_path.to_str().unwrap(),
1519 merges_path.to_str().unwrap(),
1520 )
1521 .build()
1522 .map_err(|e| FerrumError::model(format!("build BPE: {e}")))?;
1523
1524 let tokenizer = tokenizers::Tokenizer::new(bpe);
1525 Ok(tokenizer)
1526}
1527
1528fn logits_to_vec(logits: &Tensor) -> Result<Vec<f32>> {
1530 let logits = if logits.dims().len() == 3 {
1531 logits
1532 .squeeze(0)
1533 .map_err(|e| FerrumError::model(format!("squeeze: {e}")))?
1534 .squeeze(0)
1535 .map_err(|e| FerrumError::model(format!("squeeze: {e}")))?
1536 } else if logits.dims().len() == 2 {
1537 logits
1538 .squeeze(0)
1539 .map_err(|e| FerrumError::model(format!("squeeze: {e}")))?
1540 } else {
1541 logits.clone()
1542 };
1543
1544 logits
1545 .to_vec1()
1546 .map_err(|e| FerrumError::model(format!("logits to_vec1: {e}")))
1547}
1548
1549pub fn sample_token(
1557 logits: &[f32],
1558 temperature: f32,
1559 top_k: usize,
1560 _repetition_penalty: f32,
1561) -> u32 {
1562 if temperature < 0.01 {
1563 return argmax(logits);
1564 }
1565
1566 let vocab = logits.len();
1567
1568 let scaled: Vec<f32> = logits.iter().map(|&x| x / temperature).collect();
1570
1571 let mut filtered = scaled.clone();
1573 if top_k > 0 && top_k < vocab {
1574 let mut sorted = scaled.clone();
1575 sorted.sort_unstable_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
1576 let threshold = sorted[top_k - 1];
1577 for v in &mut filtered {
1578 if *v < threshold {
1579 *v = f32::NEG_INFINITY;
1580 }
1581 }
1582 }
1583
1584 const TOP_P: f32 = 0.9;
1586 {
1587 let mut indices: Vec<usize> = (0..vocab).collect();
1588 indices.sort_unstable_by(|&a, &b| {
1589 filtered[b]
1590 .partial_cmp(&filtered[a])
1591 .unwrap_or(std::cmp::Ordering::Equal)
1592 });
1593
1594 let max_val = filtered[indices[0]];
1596 let exp_sorted: Vec<f32> = indices
1597 .iter()
1598 .map(|&i| (filtered[i] - max_val).exp())
1599 .collect();
1600 let sum: f32 = exp_sorted.iter().sum();
1601 let probs_sorted: Vec<f32> = exp_sorted.iter().map(|e| e / sum).collect();
1602
1603 let mut cumsum = 0.0f32;
1605 let mut cutoff_idx = vocab;
1606 for (i, &p) in probs_sorted.iter().enumerate() {
1607 cumsum += p;
1608 if cumsum > TOP_P {
1609 cutoff_idx = i + 1;
1610 break;
1611 }
1612 }
1613
1614 for &idx in &indices[cutoff_idx..] {
1616 filtered[idx] = f32::NEG_INFINITY;
1617 }
1618 }
1619
1620 let max_val = filtered.iter().copied().fold(f32::NEG_INFINITY, f32::max);
1622 let exps: Vec<f32> = filtered.iter().map(|&v| (v - max_val).exp()).collect();
1623 let sum: f32 = exps.iter().sum();
1624 let probs: Vec<f32> = exps.iter().map(|e| e / sum).collect();
1625
1626 let r = rand_f32();
1628 let mut cumulative = 0.0f32;
1629 for (i, &p) in probs.iter().enumerate() {
1630 cumulative += p;
1631 if cumulative >= r {
1632 return i as u32;
1633 }
1634 }
1635 argmax(&probs)
1637}
1638
/// Returns the index of the largest value in `v`, or `0` when `v` is empty.
///
/// On ties the last maximal index wins. Incomparable pairs (NaN) are treated as equal, so a
/// NaN never blocks a later element from being selected.
fn argmax(v: &[f32]) -> u32 {
    let Some((&first, rest)) = v.split_first() else {
        return 0;
    };

    let mut best_idx = 0usize;
    let mut best_val = first;
    for (offset, &candidate) in rest.iter().enumerate() {
        // Only a strictly greater current best survives; Less, Equal and NaN all move on.
        let current_wins = matches!(
            best_val.partial_cmp(&candidate),
            Some(std::cmp::Ordering::Greater)
        );
        if !current_wins {
            best_idx = offset + 1;
            best_val = candidate;
        }
    }
    best_idx as u32
}
1646
/// Returns a pseudo-random `f32` uniformly distributed in `[0, 1)`.
///
/// Not cryptographically secure and not reproducible. Each call mixes the wall-clock time
/// with a process-wide call counter through the SplitMix64 finalizer. It then keeps the top
/// 24 bits, which an `f32` represents exactly, so the output covers the whole unit interval.
fn rand_f32() -> f32 {
    use std::sync::atomic::{AtomicU64, Ordering};
    static COUNTER: AtomicU64 = AtomicU64::new(0);

    // Time-derived entropy; truncating the u128 nanosecond count to u64 just wraps.
    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos() as u64;
    // Keeps successive calls distinct even when the clock has not advanced.
    let count = COUNTER.fetch_add(1, Ordering::Relaxed);

    // SplitMix64: golden-ratio stepped state followed by its 64-bit finalizer.
    let mut z = nanos.wrapping_add(count.wrapping_mul(0x9E37_79B9_7F4A_7C15));
    z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    z ^= z >> 31;

    // Normalise over the full 64-bit range. Dividing a small raw value by `u64::MAX` would
    // confine results near 0 and bias sampling toward low token ids.
    (z >> 40) as f32 / (1u64 << 24) as f32
}
1664
1665#[derive(Clone, Debug)]
1668#[allow(dead_code)]
1669struct DummyTtsCache;
1670
1671impl ferrum_interfaces::KvCacheHandle for DummyTtsCache {
1672 fn block_table(&self) -> &ferrum_interfaces::BlockTable {
1673 static EMPTY: std::sync::OnceLock<ferrum_interfaces::BlockTable> =
1674 std::sync::OnceLock::new();
1675 EMPTY.get_or_init(|| ferrum_interfaces::BlockTable::new(16))
1676 }
1677 fn block_table_mut(&mut self) -> &mut ferrum_interfaces::BlockTable {
1678 unimplemented!()
1679 }
1680 fn as_any(&self) -> &dyn std::any::Any {
1681 self
1682 }
1683 fn device(&self) -> Device {
1684 Device::CPU
1685 }
1686 fn num_layers(&self) -> usize {
1687 0
1688 }
1689 fn num_heads(&self) -> usize {
1690 0
1691 }
1692 fn head_dim(&self) -> usize {
1693 0
1694 }
1695 fn key_cache(&self, _: usize) -> Result<Option<TensorRef>> {
1696 Ok(None)
1697 }
1698 fn value_cache(&self, _: usize) -> Result<Option<TensorRef>> {
1699 Ok(None)
1700 }
1701 fn clone_handle(&self) -> Result<Arc<dyn ferrum_interfaces::KvCacheHandle>> {
1702 Ok(Arc::new(self.clone()))
1703 }
1704 fn stats(&self) -> ferrum_interfaces::CacheHandleStats {
1705 ferrum_interfaces::CacheHandleStats {
1706 memory_bytes: 0,
1707 blocks_allocated: 0,
1708 tokens_stored: 0,
1709 utilization: 0.0,
1710 last_access: std::time::Instant::now(),
1711 }
1712 }
1713 fn is_valid(&self) -> bool {
1714 true
1715 }
1716 fn cache_id(&self) -> String {
1717 "tts_dummy".to_string()
1718 }
1719}
1720
1721#[async_trait]
1722impl ModelExecutor for TtsModelExecutor {
1723 fn info(&self) -> &ModelInfo {
1724 &self.info
1725 }
1726
1727 async fn prefill(&self, _input: &PrefillInput) -> Result<PrefillOutput> {
1728 Err(FerrumError::model(
1729 "TTS uses synthesize(), not prefill/decode",
1730 ))
1731 }
1732
1733 async fn decode(&self, _input: &DecodeInput) -> Result<DecodeOutput> {
1734 Err(FerrumError::model(
1735 "TTS uses synthesize(), not prefill/decode",
1736 ))
1737 }
1738
1739 fn capabilities(&self) -> ExecutorCapabilities {
1740 ExecutorCapabilities {
1741 max_batch_size: 1,
1742 max_sequence_length: self.info.max_sequence_length,
1743 attention_mechanisms: vec![AttentionType::GroupedQuery],
1744 supports_dynamic_batching: false,
1745 supports_continuous_batching: false,
1746 supports_speculative_decoding: false,
1747 supports_tensor_parallelism: false,
1748 supports_pipeline_parallelism: false,
1749 supported_dtypes: vec![DataType::FP32, DataType::BF16],
1750 supported_devices: vec![self.info.device.clone()],
1751 memory_requirements: MemoryRequirements {
1752 parameter_memory: 0,
1753 activation_memory_per_token: 0,
1754 kv_cache_memory_per_token: 0,
1755 overhead_memory: 0,
1756 },
1757 }
1758 }
1759
1760 fn release_cache(&self, _: &str) {}
1761
1762 fn status(&self) -> ferrum_interfaces::model_executor::ExecutorStatus {
1763 common::default_executor_status()
1764 }
1765}