1#![allow(dead_code, unused_imports, unused_variables, unused_mut, unused_parens)]
7
8use std::collections::HashMap;
9use std::sync::Arc;
10
11use async_trait::async_trait;
12use candle_core::{DType, Device as CandleDevice, Tensor};
13use candle_nn::VarBuilder;
14use ferrum_interfaces::{
15 model_executor::{
16 AttentionType, DecodeInput, DecodeOutput, ExecutorCapabilities, MemoryRequirements,
17 PrefillInput, PrefillOutput,
18 },
19 ModelExecutor, TensorRef,
20};
21use ferrum_types::{DataType, Device, FerrumError, ModelInfo, ModelType, Result};
22use tracing::info;
23
24use super::common;
25use crate::architectures::qwen3_tts::{Qwen3TTSTalker, SubTalker, TalkerConfig};
26use crate::architectures::qwen3_tts_backbone::TalkerBackboneBackend;
27use crate::architectures::qwen3_tts_vocoder::{Qwen3TTSVocoder, VocoderConfig};
28use crate::architectures::speaker_encoder::{mel_spectrogram_speaker_encoder, SpeakerEncoder};
29use crate::architectures::speech_tokenizer_encoder::SpeechTokenizerEncoder;
30use ferrum_quantization::NativeSafetensorsLoader;
31
/// Install `CudaBackend` backbone overrides on the talker and the sub-talker.
///
/// Both backbones are built from a single safetensors loader opened on
/// `model_dir`. They are constructed *before* either override is installed, so
/// that a failure while building the second backbone does not leave the talker
/// overridden while the sub-talker is not.
///
/// # Errors
/// Propagates any error from opening the loader or building either backbone;
/// in that case neither `talker` nor `sub_talker` has been modified.
#[cfg(feature = "cuda")]
fn install_cuda_backend_overrides(
    cfg: &TalkerConfig,
    model_dir: &std::path::Path,
    talker: &mut Qwen3TTSTalker,
    sub_talker: &mut SubTalker,
) -> Result<()> {
    use ferrum_kernels::backend::cuda::CudaBackend;

    let loader: NativeSafetensorsLoader<CudaBackend> = NativeSafetensorsLoader::open(model_dir)?;

    // Build everything fallible first ...
    let talker_bb = TalkerBackboneBackend::<CudaBackend>::new(cfg, &loader)?;
    let sub_bb = TalkerBackboneBackend::<CudaBackend>::new_code_predictor(cfg, &loader)?;

    // ... then install both overrides together.
    talker.set_backend_override(Box::new(talker_bb));
    sub_talker.set_backend_override(Box::new(sub_bb));
    Ok(())
}
50
/// Sample rate, in Hz, reported for synthesized waveforms.
const SAMPLE_RATE: usize = 24000;
/// Upper bound on the number of generation steps (one codec frame per step).
const MAX_CODEC_TOKENS: usize = 2000;

/// Default sampling temperature when no environment override is set.
const TEMPERATURE: f32 = 0.9;
/// Top-k value handed to the token samplers.
const TOP_K: usize = 50;
/// Factor used to scale down the logits of tokens that were already generated.
const REPETITION_PENALTY: f32 = 1.05;

/// Sampling temperature for the talker.
///
/// Reads `FERRUM_TTS_TEMP`; falls back to [`TEMPERATURE`] when the variable is
/// unset or does not parse as an `f32`.
fn tts_temperature() -> f32 {
    std::env::var("FERRUM_TTS_TEMP")
        .ok()
        .and_then(|v| v.parse().ok())
        .unwrap_or(TEMPERATURE)
}

/// Sampling temperature for the sub-talker.
///
/// Reads `FERRUM_ST_TEMP`; falls back to [`tts_temperature`] when the variable
/// is unset or does not parse as an `f32`.
fn st_temperature() -> f32 {
    std::env::var("FERRUM_ST_TEMP")
        .ok()
        .and_then(|v| v.parse().ok())
        // Lazy fallback: only consult FERRUM_TTS_TEMP when the override is absent/invalid.
        .unwrap_or_else(tts_temperature)
}
75
/// Text-to-speech executor bundling every model component needed to turn text
/// into a waveform.
pub struct TtsModelExecutor {
    /// Main talker model: embeds text/codec ids, runs forward steps and
    /// produces the logits the first code of each frame is sampled from.
    talker: Qwen3TTSTalker,
    /// Predicts the remaining codes of a frame from the talker's last hidden
    /// state and the embedding of the frame's first code.
    sub_talker: SubTalker,
    /// Decodes codec frames into a waveform.
    vocoder: Qwen3TTSVocoder,
    /// Tokenizer used to turn input text into token ids.
    text_tokenizer: tokenizers::Tokenizer,
    /// Talker configuration parsed from the model's `config.json`.
    config: TalkerConfig,
    /// Descriptive model metadata assembled at load time.
    info: ModelInfo,
    /// Optional speaker encoder; `None` when it failed to load.
    speaker_encoder: Option<SpeakerEncoder>,
    /// Optional speech-tokenizer encoder; `None` when it is absent or failed
    /// to load.
    speech_tokenizer_encoder: Option<SpeechTokenizerEncoder>,
}
87
88impl TtsModelExecutor {
89 pub fn from_path(model_path: &str, device: CandleDevice, dtype: DType) -> Result<Self> {
95 let dir = std::path::Path::new(model_path);
96
97 let config_json: serde_json::Value = {
99 let config_path = dir.join("config.json");
100 let data = std::fs::read_to_string(&config_path)
101 .map_err(|e| FerrumError::model(format!("read config.json: {e}")))?;
102 serde_json::from_str(&data)
103 .map_err(|e| FerrumError::model(format!("parse config.json: {e}")))?
104 };
105 let config = TalkerConfig::from_json(&config_json)?;
106
107 let text_tokenizer = load_bpe_tokenizer(dir)?;
109
110 let talker_weights = find_safetensor_files(dir, "model")?;
112 let talker_vb = unsafe {
113 VarBuilder::from_mmaped_safetensors(&talker_weights, dtype, &device)
114 .map_err(|e| FerrumError::model(format!("load talker weights: {e}")))?
115 };
116 let mut talker = Qwen3TTSTalker::load(&config, talker_vb.clone(), device.clone())?;
117
118 let mut sub_talker = SubTalker::load(&config, talker_vb.clone(), device.clone())?;
120
121 let spk_enc_dim = config_json
123 .get("speaker_encoder_config")
124 .and_then(|c| c.get("enc_dim"))
125 .and_then(|v| v.as_u64())
126 .unwrap_or(1024) as usize;
127 let speaker_encoder =
128 SpeakerEncoder::load_with_dim(talker_vb.pp("speaker_encoder"), spk_enc_dim)
129 .map_err(|e| {
130 tracing::warn!("Speaker encoder not available: {e}");
131 e
132 })
133 .ok();
134
135 let vocoder_dir = dir.join("speech_tokenizer");
137 let vocoder_weights = find_safetensor_files(&vocoder_dir, "model")?;
138 let vocoder_vb = unsafe {
139 VarBuilder::from_mmaped_safetensors(&vocoder_weights, dtype, &device)
140 .map_err(|e| FerrumError::model(format!("load vocoder weights: {e}")))?
141 };
142 let vocoder_config = VocoderConfig::default();
143 let vocoder = Qwen3TTSVocoder::load(&vocoder_config, vocoder_vb.clone())?;
144
145 let speech_tokenizer_encoder = if vocoder_dir.join("config.json").exists() {
149 let cpu_vb = unsafe {
150 VarBuilder::from_mmaped_safetensors(&vocoder_weights, dtype, &CandleDevice::Cpu)
151 .map_err(|e| FerrumError::model(format!("load encoder cpu: {e}")))?
152 };
153 SpeechTokenizerEncoder::load(cpu_vb.pp("encoder"), CandleDevice::Cpu)
154 .map_err(|e| {
155 tracing::warn!("Speech tokenizer encoder not available: {e}");
156 e
157 })
158 .ok()
159 } else {
160 None
161 };
162
163 #[cfg(feature = "cuda")]
172 if matches!(&device, CandleDevice::Cuda(_)) {
173 match install_cuda_backend_overrides(&config, dir, &mut talker, &mut sub_talker) {
174 Ok(()) => {
175 tracing::info!(
176 "TtsModelExecutor: Backend<CudaBackend> installed for Talker + SubTalker"
177 );
178 }
179 Err(e) => {
180 tracing::warn!(
181 "TtsModelExecutor: Backend<CudaBackend> install failed ({e}); \
182 falling back to candle/fused path (CUDA voice-clone may produce garbage)"
183 );
184 }
185 }
186 }
187
188 let info = ModelInfo {
189 model_id: ferrum_types::ModelId(model_path.to_string()),
190 model_type: ModelType::Custom("qwen3-tts".to_string()),
191 hidden_size: config.hidden_size,
192 vocab_size: config.vocab_size,
193 num_layers: config.num_hidden_layers,
194 num_heads: config.num_attention_heads,
195 num_kv_heads: config.num_key_value_heads,
196 num_parameters: 0,
197 max_sequence_length: config.max_position_embeddings,
198 device: match &device {
199 CandleDevice::Cpu => Device::CPU,
200 CandleDevice::Cuda(_) => Device::CUDA(0),
201 #[cfg(any(target_os = "macos", target_os = "ios"))]
202 CandleDevice::Metal(_) => Device::Metal,
203 #[cfg(not(any(target_os = "macos", target_os = "ios")))]
204 CandleDevice::Metal(_) => Device::CPU,
205 },
206 dtype: match dtype {
207 DType::F32 => DataType::FP32,
208 DType::F16 => DataType::FP16,
209 DType::BF16 => DataType::BF16,
210 _ => DataType::FP32,
211 },
212 version: None,
213 license: None,
214 metadata: HashMap::new(),
215 };
216
217 info!(
218 "TtsModelExecutor: {} (hidden={}, layers={}, codec_groups={})",
219 model_path, config.hidden_size, config.num_hidden_layers, config.num_code_groups,
220 );
221
222 Ok(Self {
223 talker,
224 sub_talker,
225 vocoder,
226 text_tokenizer,
227 config,
228 info,
229 speaker_encoder,
230 speech_tokenizer_encoder,
231 })
232 }
233
234 pub fn synthesize(&mut self, text: &str, language: &str) -> Result<Vec<f32>> {
242 self.talker.reset();
243
244 let device = self.talker.device().clone();
245
246 let encoding = self
248 .text_tokenizer
249 .encode(text, false)
250 .map_err(|e| FerrumError::model(format!("tokenize: {e}")))?;
251 let content_ids: Vec<u32> = encoding.get_ids().to_vec();
252
253 if content_ids.is_empty() {
254 return Err(FerrumError::model("empty text after tokenization"));
255 }
256
257 info!("TTS: content tokens = {}", content_ids.len());
258
259 let codec_eos = self.config.codec_eos_token_id;
260 let tts_pad = self.config.tts_pad_token_id;
261 let tts_bos = self.config.tts_bos_token_id;
262 let tts_eos = self.config.tts_eos_token_id;
263
264 let embed_text_ids = |ids: &[u32]| -> Result<Tensor> {
266 let t = Tensor::new(ids, &device)
267 .and_then(|t| t.unsqueeze(0))
268 .map_err(|e| FerrumError::model(format!("text ids: {e}")))?;
269 self.talker.embed_text(&t)
270 };
271 let embed_codec_ids = |ids: &[u32]| -> Result<Tensor> {
272 let t = Tensor::new(ids, &device)
273 .and_then(|t| t.unsqueeze(0))
274 .map_err(|e| FerrumError::model(format!("codec ids: {e}")))?;
275 self.talker.embed_codec(&t)
276 };
277
278 let im_start_id = 151644u32; let assistant_id = 77091u32; let newline_id = 198u32; let role_prefix_ids = [im_start_id, assistant_id, newline_id];
286
287 info!(
288 "TTS: role_prefix={} content={} tokens",
289 role_prefix_ids.len(),
290 content_ids.len()
291 );
292
293 let role_embed = embed_text_ids(&role_prefix_ids)?;
295
296 let resolved_lang = if language == "auto" {
298 "chinese"
299 } else {
300 language
301 };
302 let language_id = self
303 .config
304 .codec_language_id
305 .get(&resolved_lang.to_lowercase());
306 let codec_prefix_ids = if let Some(&lang_id) = language_id {
307 vec![
308 self.config.codec_think_id,
309 self.config.codec_think_bos_id,
310 lang_id,
311 self.config.codec_think_eos_id,
312 ]
313 } else {
314 vec![
315 self.config.codec_nothink_id,
316 self.config.codec_think_bos_id,
317 self.config.codec_think_eos_id,
318 ]
319 };
320 let speaker_token = if resolved_lang == "chinese" {
323 3065u32
324 } else {
325 3061u32
326 };
327 let codec_full = {
328 let mut v = codec_prefix_ids.clone();
329 v.push(speaker_token);
330 v.push(self.config.codec_pad_id);
331 v.push(self.config.codec_bos_id);
332 v
333 };
334 let codec_embed = embed_codec_ids(&codec_full)?;
335
336 let n_codec = codec_full.len();
338 let mut tts_prefix_ids = vec![tts_pad; n_codec - 1];
339 tts_prefix_ids.push(tts_bos);
340 let tts_prefix_embed = embed_text_ids(&tts_prefix_ids)?;
341
342 let codec_first = codec_embed
344 .narrow(1, 0, n_codec - 1)
345 .map_err(|e| FerrumError::model(format!("codec narrow: {e}")))?;
346 let n_prefix = n_codec - 1; let codec_prefix_part = codec_embed
355 .narrow(1, 0, n_prefix)
356 .map_err(|e| FerrumError::model(format!("codec narrow: {e}")))?;
357
358 let mut tts_text_prefix_ids = vec![tts_pad; n_prefix - 1];
360 tts_text_prefix_ids.push(tts_bos);
361 let tts_text_embed = embed_text_ids(&tts_text_prefix_ids)?;
362
363 let codec_hidden = (&tts_text_embed + &codec_prefix_part)
364 .map_err(|e| FerrumError::model(format!("prefix sum: {e}")))?;
365
366 let codec_bos_embed = codec_embed
368 .narrow(1, n_prefix, 1)
369 .map_err(|e| FerrumError::model(format!("codec bos: {e}")))?;
370
371 let first_text_combined = if !content_ids.is_empty() {
373 let first_text_embed = embed_text_ids(&content_ids[..1])?;
374 (&first_text_embed + &codec_bos_embed)
375 .map_err(|e| FerrumError::model(format!("first text+bos: {e}")))?
376 } else {
377 codec_bos_embed.clone()
378 };
379
380 let prefill_embeds = Tensor::cat(&[&role_embed, &codec_hidden, &first_text_combined], 1)
382 .map_err(|e| FerrumError::model(format!("prefill cat: {e}")))?;
383
384 let plen = prefill_embeds.dim(1).unwrap_or(0);
385 info!("TTS: prefill_len = {}", plen);
386 if let Ok(v) = prefill_embeds
388 .narrow(0, 0, 1)
389 .and_then(|t| t.narrow(1, 0, 1))
390 .and_then(|t| t.narrow(2, 0, 5))
391 .and_then(|t| t.flatten_all())
392 .and_then(|t| t.to_vec1::<f32>())
393 {
394 info!(" prefill_input pos0[:5] = {:?}", v);
395 }
396 if plen > 0 {
397 if let Ok(v) = prefill_embeds
398 .narrow(0, 0, 1)
399 .and_then(|t| t.narrow(1, plen - 1, 1))
400 .and_then(|t| t.narrow(2, 0, 5))
401 .and_then(|t| t.flatten_all())
402 .and_then(|t| t.to_vec1::<f32>())
403 {
404 info!(" prefill_input pos-1[:5] = {:?}", v);
405 }
406 }
407
408 let mut trailing_ids: Vec<u32> = if content_ids.len() > 1 {
410 content_ids[1..].to_vec()
411 } else {
412 Vec::new()
413 };
414 trailing_ids.push(tts_eos);
415 let trailing_text_embeds = embed_text_ids(&trailing_ids)?;
416 let trailing_text_len = trailing_text_embeds
417 .dim(1)
418 .map_err(|e| FerrumError::model(format!("trailing dim: {e}")))?;
419 let tts_pad_embed = embed_text_ids(&[tts_pad])?;
420
421 info!("TTS: trailing_text_len = {}", trailing_text_len);
422
423 let mut hidden = self.talker.forward_step(&prefill_embeds)?;
425
426 let current_logits = self.talker.logits(
428 &hidden
429 .narrow(1, hidden.dim(1).unwrap() - 1, 1)
430 .map_err(|e| FerrumError::model(format!("narrow: {e}")))?,
431 )?;
432
433 let mut all_codec_tokens: Vec<Vec<u32>> = Vec::new();
435 let mut current_logits = current_logits;
436
437 let suppress_start = self.config.vocab_size.saturating_sub(1024);
439 let suppress_end = self.config.vocab_size;
440 let mut generated_tokens: Vec<u32> = Vec::new();
441
442 for step in 0..MAX_CODEC_TOKENS {
443 let mut logits_vec = logits_to_vec(¤t_logits)?;
445 for i in suppress_start..suppress_end.min(logits_vec.len()) {
447 if i as u32 != codec_eos {
448 logits_vec[i] = f32::NEG_INFINITY;
449 }
450 }
451 for &prev_tok in &generated_tokens {
453 let idx = prev_tok as usize;
454 if idx < logits_vec.len() {
455 if logits_vec[idx] > 0.0 {
456 logits_vec[idx] /= REPETITION_PENALTY;
457 } else {
458 logits_vec[idx] *= REPETITION_PENALTY;
459 }
460 }
461 }
462 let next_token =
463 sample_token(&logits_vec, tts_temperature(), TOP_K, REPETITION_PENALTY);
464
465 if step < 10 {
466 let argmax_tok = logits_vec
468 .iter()
469 .enumerate()
470 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
471 .map(|(i, v)| (i, *v))
472 .unwrap_or((0, 0.0));
473 info!(
474 "TOKEN step={} sampled={} argmax=({}, {:.2})",
475 step, next_token, argmax_tok.0, argmax_tok.1
476 );
477 }
478
479 generated_tokens.push(next_token);
480
481 if next_token == codec_eos {
483 info!("TTS: codec EOS at step {}", step);
484 break;
485 }
486
487 let last_hidden = hidden
489 .narrow(1, hidden.dim(1).unwrap() - 1, 1)
490 .map_err(|e| FerrumError::model(format!("last_hidden: {e}")))?;
491
492 let token_tensor = Tensor::new(&[next_token], &device)
494 .map_err(|e| FerrumError::model(format!("token tensor: {e}")))?
495 .unsqueeze(0)
496 .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?;
497 let first_codec_embed = self.talker.embed_codec(&token_tensor)?;
498
499 let st_t0 = std::time::Instant::now();
501 let extra_codes = self.sub_talker.predict(
502 &last_hidden,
503 &first_codec_embed,
504 st_temperature(),
505 TOP_K,
506 )?;
507 if step == 0 {
508 info!(
509 " SubTalker: {:.1}ms",
510 st_t0.elapsed().as_secs_f64() * 1000.0
511 );
512 }
513
514 let mut frame_codes = vec![next_token];
515 frame_codes.extend_from_slice(&extra_codes);
516 all_codec_tokens.push(frame_codes);
517
518 let mut combined_embed = first_codec_embed.clone();
521 for (i, &code) in extra_codes.iter().enumerate() {
522 let code_t = Tensor::new(&[code], &device)
523 .and_then(|t| t.unsqueeze(0))
524 .map_err(|e| FerrumError::model(format!("code_t: {e}")))?;
525 let sub_embed = code_t
526 .apply(&self.sub_talker.codec_embeddings[i])
527 .map_err(|e| FerrumError::model(format!("sub_embed: {e}")))?;
528 combined_embed = (combined_embed + sub_embed)
529 .map_err(|e| FerrumError::model(format!("add embed: {e}")))?;
530 }
531
532 if step == 0 {
533 if let Ok(v) = combined_embed
534 .flatten_all()
535 .and_then(|t| t.narrow(0, 0, 5))
536 .and_then(|t| t.to_vec1::<f32>())
537 {
538 info!("STEP0 codec_sum[:5] = {:?} (before trailing)", v);
539 }
540 }
541 if step < trailing_text_len {
546 let trail = trailing_text_embeds
547 .narrow(1, step, 1)
548 .map_err(|e| FerrumError::model(format!("trailing narrow: {e}")))?;
549 combined_embed = (combined_embed + trail)
550 .map_err(|e| FerrumError::model(format!("add trailing: {e}")))?;
551 } else {
552 combined_embed = (combined_embed + &tts_pad_embed)
553 .map_err(|e| FerrumError::model(format!("add tts_pad: {e}")))?;
554 }
555
556 if step == 0 {
558 if let Ok(v) = first_codec_embed
559 .flatten_all()
560 .and_then(|t| t.narrow(0, 0, 5))
561 .and_then(|t| t.to_vec1::<f32>())
562 {
563 info!("STEP0 semantic[:5] = {:?}", v);
564 }
565 if let Ok(v) = combined_embed
566 .flatten_all()
567 .and_then(|t| t.narrow(0, 0, 5))
568 .and_then(|t| t.to_vec1::<f32>())
569 {
570 info!("STEP0 combined[:5] = {:?}", v);
571 }
572 }
573 let tk_t0 = std::time::Instant::now();
575 hidden = self.talker.forward_step(&combined_embed)?;
576 current_logits = self.talker.logits(&hidden)?;
577 if step == 0 {
578 info!(
579 " Talker step: {:.1}ms",
580 tk_t0.elapsed().as_secs_f64() * 1000.0
581 );
582 }
583 }
584
585 if all_codec_tokens.is_empty() {
586 return Err(FerrumError::model("no codec tokens generated"));
587 }
588
589 info!("TTS: generated {} codec frames", all_codec_tokens.len());
590
591 let num_frames = all_codec_tokens.len();
593 let num_groups = self.config.num_code_groups;
594 let mut flat_codes: Vec<u32> = vec![0; num_groups * num_frames];
595 for (t, frame) in all_codec_tokens.iter().enumerate() {
596 for (g, &code) in frame.iter().enumerate() {
597 flat_codes[g * num_frames + t] = code;
598 }
599 }
600
601 let codebook_size = 2048u32;
604 for code in &mut flat_codes {
605 if *code >= codebook_size {
606 *code = 0; }
608 }
609
610 let codes_tensor = Tensor::new(&flat_codes[..], &device)
611 .map_err(|e| FerrumError::model(format!("codes tensor: {e}")))?
612 .reshape((1, num_groups, num_frames))
613 .map_err(|e| FerrumError::model(format!("reshape codes: {e}")))?;
614
615 let waveform = self.vocoder.decode(&codes_tensor)?;
617
618 let samples: Vec<f32> = waveform
620 .squeeze(0)
621 .map_err(|e| FerrumError::model(format!("squeeze batch: {e}")))?
622 .squeeze(0)
623 .map_err(|e| FerrumError::model(format!("squeeze channel: {e}")))?
624 .to_vec1()
625 .map_err(|e| FerrumError::model(format!("to_vec1: {e}")))?;
626
627 info!(
628 "TTS: waveform {} samples ({:.2}s @ {}Hz)",
629 samples.len(),
630 samples.len() as f64 / SAMPLE_RATE as f64,
631 SAMPLE_RATE,
632 );
633
634 Ok(samples)
635 }
636
637 fn decode_frames(&mut self, frames: &[Vec<u32>], device: &CandleDevice) -> Result<Vec<f32>> {
640 let num_frames = frames.len();
641 if num_frames == 0 {
642 return Ok(vec![]);
643 }
644 let num_groups = self.config.num_code_groups;
645 let codebook_size = 2048u32;
646
647 let mut flat_codes: Vec<u32> = vec![0; num_groups * num_frames];
648 for (t, frame) in frames.iter().enumerate() {
649 for (g, &code) in frame.iter().take(num_groups).enumerate() {
650 flat_codes[g * num_frames + t] = if code >= codebook_size { 0 } else { code };
651 }
652 }
653
654 let codes_tensor = Tensor::new(&flat_codes[..], device)
655 .map_err(|e| FerrumError::model(format!("codes tensor: {e}")))?
656 .reshape((1, num_groups, num_frames))
657 .map_err(|e| FerrumError::model(format!("reshape codes: {e}")))?;
658
659 let waveform = self.vocoder.decode(&codes_tensor)?;
660 waveform
661 .squeeze(0)
662 .and_then(|t| t.squeeze(0))
663 .and_then(|t| t.to_vec1())
664 .map_err(|e| FerrumError::model(format!("waveform extract: {e}")))
665 }
666
667 pub fn synthesize_streaming<F: FnMut(usize, &[f32])>(
672 &mut self,
673 text: &str,
674 language: &str,
675 chunk_frames: usize,
676 mut on_chunk: F,
677 ) -> Result<Vec<Vec<f32>>> {
678 self.talker.reset();
681 let device = self.talker.device().clone();
682
683 let encoding = self
684 .text_tokenizer
685 .encode(text, false)
686 .map_err(|e| FerrumError::model(format!("tokenize: {e}")))?;
687 let content_ids: Vec<u32> = encoding.get_ids().to_vec();
688 if content_ids.is_empty() {
689 return Err(FerrumError::model("empty text after tokenization"));
690 }
691
692 let codec_eos = self.config.codec_eos_token_id;
693 let tts_pad = self.config.tts_pad_token_id;
694 let tts_bos = self.config.tts_bos_token_id;
695 let tts_eos = self.config.tts_eos_token_id;
696
697 let embed_text_ids = |ids: &[u32]| -> Result<Tensor> {
699 let t = Tensor::new(ids, &device)
700 .map_err(|e| FerrumError::model(format!("text tensor: {e}")))?
701 .unsqueeze(0)
702 .map_err(|e| FerrumError::model(format!("text unsqueeze: {e}")))?;
703 self.talker.embed_text(&t)
704 };
705 let embed_codec_ids = |ids: &[u32]| -> Result<Tensor> {
706 let t = Tensor::new(ids, &device)
707 .map_err(|e| FerrumError::model(format!("codec tensor: {e}")))?
708 .unsqueeze(0)
709 .map_err(|e| FerrumError::model(format!("codec unsqueeze: {e}")))?;
710 self.talker.embed_codec(&t)
711 };
712
713 let resolved_lang = if language.eq_ignore_ascii_case("auto") {
715 "chinese"
716 } else {
717 language
718 };
719 let language_id = self
720 .config
721 .codec_language_id
722 .get(&resolved_lang.to_lowercase());
723 let codec_prefix_ids = if let Some(&lang_id) = language_id {
724 vec![
725 self.config.codec_think_id,
726 self.config.codec_think_bos_id,
727 lang_id,
728 self.config.codec_think_eos_id,
729 ]
730 } else {
731 vec![
732 self.config.codec_nothink_id,
733 self.config.codec_think_bos_id,
734 self.config.codec_think_eos_id,
735 ]
736 };
737 let speaker_token = if resolved_lang == "chinese" {
738 3065u32
739 } else {
740 3061u32
741 };
742 let mut codec_ids = codec_prefix_ids;
743 codec_ids.push(speaker_token);
744 codec_ids.push(self.config.codec_pad_id);
745 codec_ids.push(self.config.codec_bos_id);
746 let codec_embed = embed_codec_ids(&codec_ids)?;
747 let n_codec = codec_embed
748 .dim(1)
749 .map_err(|e| FerrumError::model(format!("dim: {e}")))?;
750 let n_prefix = n_codec - 1;
751 let codec_prefix_part = codec_embed
752 .narrow(1, 0, n_prefix)
753 .map_err(|e| FerrumError::model(format!("narrow: {e}")))?;
754
755 let mut tts_text_prefix_ids = vec![tts_pad; n_prefix - 1];
756 tts_text_prefix_ids.push(tts_bos);
757 let tts_text_embed = embed_text_ids(&tts_text_prefix_ids)?;
758 let codec_hidden = (&tts_text_embed + &codec_prefix_part)
759 .map_err(|e| FerrumError::model(format!("sum: {e}")))?;
760 let codec_bos_embed = codec_embed
761 .narrow(1, n_prefix, 1)
762 .map_err(|e| FerrumError::model(format!("bos: {e}")))?;
763
764 let role_ids: &[u32] = &[151644, 77091, 198]; let role_embed = embed_text_ids(role_ids)?;
766
767 let first_text_combined = if !content_ids.is_empty() {
768 let first_text_embed = embed_text_ids(&content_ids[..1])?;
769 (&first_text_embed + &codec_bos_embed)
770 .map_err(|e| FerrumError::model(format!("first: {e}")))?
771 } else {
772 codec_bos_embed.clone()
773 };
774
775 let prefill_embeds = Tensor::cat(&[&role_embed, &codec_hidden, &first_text_combined], 1)
776 .map_err(|e| FerrumError::model(format!("prefill cat: {e}")))?;
777
778 let trailing_text_embeds = if content_ids.len() > 1 {
780 let remaining = embed_text_ids(&content_ids[1..])?;
781 let eos = embed_text_ids(&[tts_eos])?;
782 Tensor::cat(&[&remaining, &eos], 1)
783 .map_err(|e| FerrumError::model(format!("trailing: {e}")))?
784 } else {
785 embed_text_ids(&[tts_eos])?
786 };
787 let trailing_text_len = trailing_text_embeds
788 .dim(1)
789 .map_err(|e| FerrumError::model(format!("dim: {e}")))?;
790 let tts_pad_embed = embed_text_ids(&[tts_pad])?;
791
792 let mut hidden = self.talker.forward_step(&prefill_embeds)?;
794 let mut current_logits = self.talker.logits(
795 &hidden
796 .narrow(1, hidden.dim(1).unwrap() - 1, 1)
797 .map_err(|e| FerrumError::model(format!("narrow: {e}")))?,
798 )?;
799
800 let suppress_start = self.config.vocab_size.saturating_sub(1024);
802 let suppress_end = self.config.vocab_size;
803 let mut generated_tokens: Vec<u32> = Vec::new();
804 let mut frame_buffer: Vec<Vec<u32>> = Vec::new();
805 let mut audio_chunks: Vec<Vec<f32>> = Vec::new();
806
807 for step in 0..MAX_CODEC_TOKENS {
808 let mut logits_vec = logits_to_vec(¤t_logits)?;
809 for i in suppress_start..suppress_end.min(logits_vec.len()) {
810 if i as u32 != codec_eos {
811 logits_vec[i] = f32::NEG_INFINITY;
812 }
813 }
814 for &prev_tok in &generated_tokens {
815 let idx = prev_tok as usize;
816 if idx < logits_vec.len() {
817 if logits_vec[idx] > 0.0 {
818 logits_vec[idx] /= REPETITION_PENALTY;
819 } else {
820 logits_vec[idx] *= REPETITION_PENALTY;
821 }
822 }
823 }
824 let next_token =
825 sample_token(&logits_vec, tts_temperature(), TOP_K, REPETITION_PENALTY);
826 generated_tokens.push(next_token);
827
828 if next_token == codec_eos {
829 info!("TTS streaming: EOS at step {}", step);
830 break;
831 }
832
833 let last_hidden = hidden
834 .narrow(1, hidden.dim(1).unwrap() - 1, 1)
835 .map_err(|e| FerrumError::model(format!("narrow: {e}")))?;
836 let token_tensor = Tensor::new(&[next_token], &device)
837 .and_then(|t| t.unsqueeze(0))
838 .map_err(|e| FerrumError::model(format!("tok: {e}")))?;
839 let first_codec_embed = self.talker.embed_codec(&token_tensor)?;
840
841 let extra_codes = self.sub_talker.predict(
842 &last_hidden,
843 &first_codec_embed,
844 st_temperature(),
845 TOP_K,
846 )?;
847
848 let mut frame = vec![next_token];
849 frame.extend_from_slice(&extra_codes);
850 frame_buffer.push(frame);
851
852 if frame_buffer.len() >= chunk_frames {
854 let chunk_audio = self.decode_frames(&frame_buffer, &device)?;
855 on_chunk(audio_chunks.len(), &chunk_audio);
856 audio_chunks.push(chunk_audio);
857 frame_buffer.clear();
858 }
859
860 let mut combined_embed = first_codec_embed.clone();
862 for (i, &code) in extra_codes.iter().enumerate() {
863 let code_t = Tensor::new(&[code], &device)
864 .and_then(|t| t.unsqueeze(0))
865 .map_err(|e| FerrumError::model(format!("code_t: {e}")))?;
866 let sub_embed = code_t
867 .apply(&self.sub_talker.codec_embeddings[i])
868 .map_err(|e| FerrumError::model(format!("sub: {e}")))?;
869 combined_embed = (combined_embed + sub_embed)
870 .map_err(|e| FerrumError::model(format!("add: {e}")))?;
871 }
872 if step < trailing_text_len {
873 let trail = trailing_text_embeds
874 .narrow(1, step, 1)
875 .map_err(|e| FerrumError::model(format!("trail: {e}")))?;
876 combined_embed = (combined_embed + trail)
877 .map_err(|e| FerrumError::model(format!("add trail: {e}")))?;
878 } else {
879 combined_embed = (combined_embed + &tts_pad_embed)
880 .map_err(|e| FerrumError::model(format!("pad: {e}")))?;
881 }
882
883 hidden = self.talker.forward_step(&combined_embed)?;
884 current_logits = self.talker.logits(&hidden)?;
885 }
886
887 if !frame_buffer.is_empty() {
889 let chunk_audio = self.decode_frames(&frame_buffer, &device)?;
890 on_chunk(audio_chunks.len(), &chunk_audio);
891 audio_chunks.push(chunk_audio);
892 }
893
894 Ok(audio_chunks)
895 }
896
    /// Sample rate of the synthesized audio in Hz (the `SAMPLE_RATE` constant).
    pub fn sample_rate(&self) -> usize {
        SAMPLE_RATE
    }
901
    /// Borrow the talker configuration this executor was loaded with.
    pub fn config(&self) -> &TalkerConfig {
        &self.config
    }
905
906 pub fn synthesize_voice_clone(
914 &mut self,
915 text: &str,
916 language: &str,
917 ref_audio_path: &str,
918 ref_text: &str,
919 ) -> Result<Vec<f32>> {
920 let device = self.talker.device().clone();
921
922 let ref_pcm = if let Ok(path) = std::env::var("FERRUM_REF_PCM") {
924 let data = std::fs::read(&path)
925 .map_err(|e| FerrumError::model(format!("read ref pcm: {e}")))?;
926 let pcm: Vec<f32> = data
927 .chunks(4)
928 .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
929 .collect();
930 info!("Loaded ref PCM override: {} samples", pcm.len());
931 pcm
932 } else {
933 crate::audio_processor::load_audio_at_rate(ref_audio_path, 24000)?
934 };
935 info!(
936 "TTS voice clone: loaded ref audio {} samples ({:.2}s)",
937 ref_pcm.len(),
938 ref_pcm.len() as f64 / 24000.0
939 );
940
941 let t0 = std::time::Instant::now();
942 let speaker_encoder = self
944 .speaker_encoder
945 .as_ref()
946 .ok_or_else(|| FerrumError::model("speaker encoder not loaded"))?;
947 let mel = mel_spectrogram_speaker_encoder(&ref_pcm);
948 let n_mel_frames = mel.len() / 128;
949 let mel_tensor = Tensor::from_vec(mel, (1, n_mel_frames, 128), &device)
952 .map_err(|e| FerrumError::model(format!("mel tensor: {e}")))?;
953 let spk_embed = speaker_encoder.forward(&mel_tensor)?;
954 info!(
955 "Step 2 (speaker embed): {:.1}ms",
956 t0.elapsed().as_secs_f64() * 1000.0
957 );
958 let spk_embed = spk_embed
960 .unsqueeze(0)
961 .map_err(|e| FerrumError::model(format!("spk unsqueeze(0): {e}")))?
962 .unsqueeze(0)
963 .map_err(|e| FerrumError::model(format!("spk unsqueeze(0) 2: {e}")))?;
964
965 let t1 = std::time::Instant::now();
966 let speech_enc = self
968 .speech_tokenizer_encoder
969 .as_ref()
970 .ok_or_else(|| FerrumError::model("speech tokenizer encoder not loaded"))?;
971 let ref_codes = if let Ok(path) = std::env::var("FERRUM_REF_CODES") {
973 let data = std::fs::read(&path)
974 .map_err(|e| FerrumError::model(format!("read ref codes: {e}")))?;
975 let u32s: Vec<u32> = data
976 .chunks(4)
977 .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
978 .collect();
979 let ncb = self.config.num_code_groups;
980 let nframes = u32s.len() / ncb;
981 info!(
982 "Loaded pre-computed ref codes: {} frames from {}",
983 nframes, path
984 );
985 u32s.chunks(ncb).map(|c| c.to_vec()).collect()
986 } else {
987 let codes = speech_enc.encode(&ref_pcm)?;
988 info!(
989 "Step 3 (speech tokenizer): {:.1}ms",
990 t1.elapsed().as_secs_f64() * 1000.0
991 );
992 codes
993 };
994 let ref_frames = ref_codes.len();
995 info!(
996 "TTS voice clone: ref_frames={}, spk_embed loaded",
997 ref_frames
998 );
999 for i in 0..ref_frames.min(5) {
1001 info!(" rust codec frame {}: {:?}", i, &ref_codes[i]);
1002 }
1003
1004 info!(
1005 "Step 3 (speech tokenizer): {:.1}ms",
1006 t1.elapsed().as_secs_f64() * 1000.0
1007 );
1008 let t2 = std::time::Instant::now();
1009 let chat_text = format!("<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n");
1011 let encoding = self
1012 .text_tokenizer
1013 .encode(chat_text.as_str(), false)
1014 .map_err(|e| FerrumError::model(format!("tokenize: {e}")))?;
1015 let input_ids: Vec<u32> = encoding.get_ids().to_vec();
1016 let role_ids = &input_ids[..3];
1018 let text_content_ids = &input_ids[3..input_ids.len().saturating_sub(5)];
1019
1020 let ref_chat_text = format!("<|im_start|>assistant\n{ref_text}<|im_end|>\n");
1022 let ref_encoding = self
1023 .text_tokenizer
1024 .encode(ref_chat_text.as_str(), false)
1025 .map_err(|e| FerrumError::model(format!("tokenize ref: {e}")))?;
1026 let ref_ids: Vec<u32> = ref_encoding.get_ids().to_vec();
1027 let ref_text_ids = &ref_ids[3..ref_ids.len().saturating_sub(2)];
1029
1030 self.talker.reset();
1032
1033 let tts_bos = self.config.tts_bos_token_id;
1034 let tts_eos = self.config.tts_eos_token_id;
1035 let tts_pad = self.config.tts_pad_token_id;
1036 let codec_bos = self.config.codec_bos_id;
1037 let codec_eos = self.config.codec_eos_token_id;
1038 let codec_pad = self.config.codec_pad_id;
1039
1040 let embed_codec_ids = |ids: &[u32]| -> Result<Tensor> {
1042 let t = Tensor::new(ids, &device)
1043 .map_err(|e| FerrumError::model(format!("codec tensor: {e}")))?
1044 .unsqueeze(0)
1045 .map_err(|e| FerrumError::model(format!("codec unsqueeze: {e}")))?;
1046 self.talker.embed_codec(&t)
1047 };
1048 let embed_text_ids = |ids: &[u32]| -> Result<Tensor> {
1049 let t = Tensor::new(ids, &device)
1050 .map_err(|e| FerrumError::model(format!("text tensor: {e}")))?
1051 .unsqueeze(0)
1052 .map_err(|e| FerrumError::model(format!("text unsqueeze: {e}")))?;
1053 self.talker.embed_text(&t)
1054 };
1055
1056 let tts_special = embed_text_ids(&[tts_bos, tts_eos, tts_pad])?;
1058 let tts_bos_embed = tts_special
1059 .narrow(1, 0, 1)
1060 .map_err(|e| FerrumError::model(format!("tts_bos narrow: {e}")))?;
1061 let tts_eos_embed = tts_special
1062 .narrow(1, 1, 1)
1063 .map_err(|e| FerrumError::model(format!("tts_eos narrow: {e}")))?;
1064 let tts_pad_embed = tts_special
1065 .narrow(1, 2, 1)
1066 .map_err(|e| FerrumError::model(format!("tts_pad narrow: {e}")))?;
1067
1068 let resolved_lang = if language.eq_ignore_ascii_case("auto") {
1070 "chinese"
1071 } else {
1072 language
1073 };
1074 let language_id = self
1075 .config
1076 .codec_language_id
1077 .get(&resolved_lang.to_lowercase());
1078
1079 let codec_prefix_ids = if let Some(&lang_id) = language_id {
1081 vec![
1082 self.config.codec_think_id,
1083 self.config.codec_think_bos_id,
1084 lang_id,
1085 self.config.codec_think_eos_id,
1086 ]
1087 } else {
1088 vec![
1089 self.config.codec_nothink_id,
1090 self.config.codec_think_bos_id,
1091 self.config.codec_think_eos_id,
1092 ]
1093 };
1094 let codec_prefix_embed = embed_codec_ids(&codec_prefix_ids)?;
1095
1096 let codec_suffix_embed = embed_codec_ids(&[codec_pad, codec_bos])?;
1098
1099 let codec_input = Tensor::cat(&[&codec_prefix_embed, &spk_embed, &codec_suffix_embed], 1)
1101 .map_err(|e| FerrumError::model(format!("codec_input cat: {e}")))?;
1102 let codec_len = codec_input
1103 .dim(1)
1104 .map_err(|e| FerrumError::model(format!("codec_len dim: {e}")))?;
1105
1106 let role_embed = embed_text_ids(role_ids)?;
1108
1109 let n_pads = codec_len - 2;
1111 let mut text_prefix_parts = Vec::new();
1112 for _ in 0..n_pads {
1113 text_prefix_parts.push(tts_pad_embed.clone());
1114 }
1115 text_prefix_parts.push(tts_bos_embed.clone());
1116 let text_prefix_refs: Vec<&Tensor> = text_prefix_parts.iter().collect();
1117 let text_prefix = Tensor::cat(&text_prefix_refs, 1)
1118 .map_err(|e| FerrumError::model(format!("text_prefix cat: {e}")))?;
1119 let codec_prefix_part = codec_input
1120 .narrow(1, 0, codec_len - 1)
1121 .map_err(|e| FerrumError::model(format!("codec prefix narrow: {e}")))?;
1122 let text_codec_prefix = (&text_prefix + &codec_prefix_part)
1123 .map_err(|e| FerrumError::model(format!("text+codec prefix sum: {e}")))?;
1124
1125 let prefill_embed = Tensor::cat(&[&role_embed, &text_codec_prefix], 1)
1127 .map_err(|e| FerrumError::model(format!("prefill cat: {e}")))?;
1128
1129 let t3 = std::time::Instant::now();
1130
1131 let all_text_ids: Vec<u32> = ref_text_ids
1134 .iter()
1135 .chain(text_content_ids.iter())
1136 .copied()
1137 .collect();
1138 let text_embed = embed_text_ids(&all_text_ids)?;
1139 let text_embed_with_eos = Tensor::cat(&[&text_embed, &tts_eos_embed], 1)
1140 .map_err(|e| FerrumError::model(format!("text+eos cat: {e}")))?;
1141 let text_len = text_embed_with_eos
1142 .dim(1)
1143 .map_err(|e| FerrumError::model(format!("text_len dim: {e}")))?;
1144
1145 let t_codec_start = std::time::Instant::now();
1148 let ncg = self.config.num_code_groups;
1149 let first_codes: Vec<u32> = ref_codes.iter().map(|f| f[0]).collect();
1150 let codec_frames_cat = {
1151 let mut sum = embed_codec_ids(&first_codes)?;
1153 for cb in 0..(ncg - 1) {
1155 let codes: Vec<u32> = ref_codes.iter().map(|f| f[cb + 1]).collect();
1156 let codes_t = Tensor::new(codes.as_slice(), &device)
1157 .map_err(|e| FerrumError::model(format!("batch codes: {e}")))?
1158 .unsqueeze(0)
1159 .map_err(|e| FerrumError::model(format!("batch unsqueeze: {e}")))?;
1160 let sub_embed = codes_t
1161 .apply(&self.sub_talker.codec_embeddings[cb])
1162 .map_err(|e| FerrumError::model(format!("batch sub_embed: {e}")))?;
1163 sum =
1164 (sum + sub_embed).map_err(|e| FerrumError::model(format!("batch add: {e}")))?;
1165 }
1166 sum
1167 };
1168 info!(
1169 "Codec embedding: {:.1}ms ({} frames × {} codebooks)",
1170 t_codec_start.elapsed().as_secs_f64() * 1000.0,
1171 ref_codes.len(),
1172 ncg
1173 );
1174
1175 let t_merge = std::time::Instant::now();
1177 let codec_bos_for_icl = embed_codec_ids(&[codec_bos])?;
1178 let icl_codec = Tensor::cat(&[&codec_bos_for_icl, &codec_frames_cat], 1)
1179 .map_err(|e| FerrumError::model(format!("icl_codec cat: {e}")))?;
1180 let codec_icl_len = icl_codec
1181 .dim(1)
1182 .map_err(|e| FerrumError::model(format!("codec_icl_len dim: {e}")))?;
1183
1184 let icl_trailing: Tensor;
1186 let icl_embed: Tensor;
1187 if text_len > codec_icl_len {
1188 let text_part = text_embed_with_eos
1189 .narrow(1, 0, codec_icl_len)
1190 .map_err(|e| FerrumError::model(format!("text_part narrow: {e}")))?;
1191 icl_embed = (&text_part + &icl_codec)
1192 .map_err(|e| FerrumError::model(format!("text+codec sum: {e}")))?;
1193 icl_trailing = text_embed_with_eos
1194 .narrow(1, codec_icl_len, text_len - codec_icl_len)
1195 .map_err(|e| FerrumError::model(format!("trailing narrow: {e}")))?;
1196 } else {
1197 let n_pad = codec_icl_len - text_len;
1198 let text_padded = if n_pad > 0 {
1199 let pad_block = tts_pad_embed
1200 .expand((1, n_pad, self.config.hidden_size))
1201 .map_err(|e| FerrumError::model(format!("pad expand: {e}")))?;
1202 Tensor::cat(&[&text_embed_with_eos, &pad_block], 1)
1203 .map_err(|e| FerrumError::model(format!("text_padded cat: {e}")))?
1204 } else {
1205 text_embed_with_eos.clone()
1206 };
1207 icl_embed = (&text_padded + &icl_codec)
1208 .map_err(|e| FerrumError::model(format!("padded+codec sum: {e}")))?;
1209 icl_trailing = tts_pad_embed.clone();
1210 }
1211 let trailing_text_len = icl_trailing
1212 .dim(1)
1213 .map_err(|e| FerrumError::model(format!("trailing dim: {e}")))?;
1214
1215 let _prefill_out = self.talker.forward_step(&prefill_embed)?;
1218 let t_icl = std::time::Instant::now();
1219 let icl_hidden = self.talker.forward_step(&icl_embed)?;
1220 let icl_len = icl_hidden
1221 .dim(1)
1222 .map_err(|e| FerrumError::model(format!("icl_hidden dim: {e}")))?;
1223 info!(
1224 "ICL block: {:.1}ms ({} tokens), trailing={}",
1225 t_icl.elapsed().as_secs_f64() * 1000.0,
1226 icl_len,
1227 trailing_text_len
1228 );
1229
1230 let mut hidden = icl_hidden;
1232 let hidden_len = hidden
1233 .dim(1)
1234 .map_err(|e| FerrumError::model(format!("hidden dim: {e}")))?;
1235 let last_hidden = hidden
1236 .narrow(1, hidden_len - 1, 1)
1237 .map_err(|e| FerrumError::model(format!("narrow last: {e}")))?;
1238 if let Ok(v) = last_hidden.flatten_all().and_then(|t| t.to_vec1::<f32>()) {}
1239 let current_logits = self.talker.logits(&last_hidden)?;
1240 {}
1241
1242 let mut all_codec_tokens: Vec<Vec<u32>> = Vec::new();
1244 let mut current_logits = current_logits;
1245
1246 let suppress_start = self.config.vocab_size.saturating_sub(1024);
1248 let suppress_end = self.config.vocab_size;
1249
1250 const ICL_REPETITION_PENALTY: f32 = 1.5;
1253 const ICL_FRAMES_PER_TOKEN: usize = 6;
1254 const ICL_MIN_FRAMES: usize = 75;
1255 let max_icl_tokens = ICL_MIN_FRAMES.max(text_content_ids.len() * ICL_FRAMES_PER_TOKEN);
1256 let mut generated_tokens: Vec<u32> = Vec::new();
1257
1258 for step in 0..max_icl_tokens {
1259 let mut logits_vec = logits_to_vec(¤t_logits)?;
1260 for i in suppress_start..suppress_end.min(logits_vec.len()) {
1262 if i as u32 != codec_eos {
1263 logits_vec[i] = f32::NEG_INFINITY;
1264 }
1265 }
1266 let min_frames: usize = std::env::var("FERRUM_TTS_MIN_FRAMES")
1270 .ok()
1271 .and_then(|v| v.parse().ok())
1272 .unwrap_or_else(|| text_content_ids.len() * ICL_FRAMES_PER_TOKEN);
1273 if step < min_frames {
1274 if let Some(v) = logits_vec.get_mut(codec_eos as usize) {
1275 *v = f32::NEG_INFINITY;
1276 }
1277 }
1278 for &prev_tok in &generated_tokens {
1280 let idx = prev_tok as usize;
1281 if idx < logits_vec.len() {
1282 if logits_vec[idx] > 0.0 {
1283 logits_vec[idx] /= ICL_REPETITION_PENALTY;
1284 } else {
1285 logits_vec[idx] *= ICL_REPETITION_PENALTY;
1286 }
1287 }
1288 }
1289 let next_token = sample_token(
1290 &logits_vec,
1291 tts_temperature(),
1292 TOP_K,
1293 ICL_REPETITION_PENALTY,
1294 );
1295
1296 generated_tokens.push(next_token);
1297
1298 if next_token == codec_eos {
1299 info!("TTS voice clone: codec EOS at step {}", step);
1300 break;
1301 }
1302
1303 if generated_tokens.len() >= 6 {
1305 let n = generated_tokens.len();
1306 let mut is_repeat = false;
1307 for pat_len in 1..=4 {
1308 if n >= pat_len * 3 {
1309 let a = &generated_tokens[n - pat_len * 3..n - pat_len * 2];
1310 let b = &generated_tokens[n - pat_len * 2..n - pat_len];
1311 let c = &generated_tokens[n - pat_len..n];
1312 if a == b && b == c {
1313 is_repeat = true;
1314 break;
1315 }
1316 }
1317 }
1318 if is_repeat {
1319 info!(
1320 "TTS voice clone: repetition detected at step {}, stopping",
1321 step
1322 );
1323 break;
1324 }
1325 }
1326
1327 let cur_hidden_len = hidden
1328 .dim(1)
1329 .map_err(|e| FerrumError::model(format!("hidden dim: {e}")))?;
1330 let last_hidden = hidden
1331 .narrow(1, cur_hidden_len - 1, 1)
1332 .map_err(|e| FerrumError::model(format!("last_hidden: {e}")))?;
1333
1334 let token_tensor = Tensor::new(&[next_token], &device)
1335 .map_err(|e| FerrumError::model(format!("token tensor: {e}")))?
1336 .unsqueeze(0)
1337 .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?;
1338 let first_codec_embed = self.talker.embed_codec(&token_tensor)?;
1339
1340 let extra_codes = self.sub_talker.predict(
1341 &last_hidden,
1342 &first_codec_embed,
1343 st_temperature(),
1344 TOP_K,
1345 )?;
1346
1347 let mut frame_codes = vec![next_token];
1348 frame_codes.extend_from_slice(&extra_codes);
1349 all_codec_tokens.push(frame_codes);
1350
1351 let mut combined_embed = first_codec_embed.clone();
1353 for (i, &code) in extra_codes.iter().enumerate() {
1354 let code_t = Tensor::new(&[code], &device)
1355 .and_then(|t| t.unsqueeze(0))
1356 .map_err(|e| FerrumError::model(format!("code_t: {e}")))?;
1357 let sub_embed = code_t
1358 .apply(&self.sub_talker.codec_embeddings[i])
1359 .map_err(|e| FerrumError::model(format!("sub_embed: {e}")))?;
1360 combined_embed = (combined_embed + sub_embed)
1361 .map_err(|e| FerrumError::model(format!("add embed: {e}")))?;
1362 }
1363
1364 if step < trailing_text_len {
1366 let trail = icl_trailing
1367 .narrow(1, step, 1)
1368 .map_err(|e| FerrumError::model(format!("trailing narrow: {e}")))?;
1369 combined_embed = (combined_embed + trail)
1370 .map_err(|e| FerrumError::model(format!("add trailing: {e}")))?;
1371 } else {
1372 combined_embed = (combined_embed + &tts_pad_embed)
1373 .map_err(|e| FerrumError::model(format!("add tts_pad: {e}")))?;
1374 }
1375
1376 hidden = self.talker.forward_step(&combined_embed)?;
1377 current_logits = self.talker.logits(&hidden)?;
1378 }
1379
1380 if all_codec_tokens.is_empty() {
1381 return Err(FerrumError::model("no codec tokens generated"));
1382 }
1383 info!(
1384 "TTS voice clone: generated {} codec frames",
1385 all_codec_tokens.len()
1386 );
1387
1388 let mut all_codes_with_ref = ref_codes.clone();
1390 all_codes_with_ref.extend_from_slice(&all_codec_tokens);
1391
1392 let num_frames = all_codes_with_ref.len();
1393 let num_groups = self.config.num_code_groups;
1394
1395 let mut flat_codes: Vec<u32> = vec![0; num_groups * num_frames];
1397 for (t, frame) in all_codes_with_ref.iter().enumerate() {
1398 for (g, &code) in frame.iter().take(num_groups).enumerate() {
1399 flat_codes[g * num_frames + t] = code;
1400 }
1401 }
1402
1403 let codebook_size = 2048u32;
1405 for code in &mut flat_codes {
1406 if *code >= codebook_size {
1407 *code = 0;
1408 }
1409 }
1410
1411 let codes_tensor = Tensor::new(&flat_codes[..], &device)
1412 .map_err(|e| FerrumError::model(format!("codes tensor: {e}")))?
1413 .reshape((1, num_groups, num_frames))
1414 .map_err(|e| FerrumError::model(format!("reshape codes: {e}")))?;
1415
1416 let waveform = self.vocoder.decode(&codes_tensor)?;
1417
1418 let samples: Vec<f32> = waveform
1419 .squeeze(0)
1420 .map_err(|e| FerrumError::model(format!("squeeze batch: {e}")))?
1421 .squeeze(0)
1422 .map_err(|e| FerrumError::model(format!("squeeze channel: {e}")))?
1423 .to_vec1()
1424 .map_err(|e| FerrumError::model(format!("to_vec1: {e}")))?;
1425
1426 let ref_ratio = ref_frames as f64 / num_frames as f64;
1428 let cut = (ref_ratio * samples.len() as f64) as usize;
1429 let output_samples = samples[cut..].to_vec();
1430
1431 info!(
1432 "TTS voice clone: waveform {} samples ({:.2}s), trimmed ref {} samples",
1433 output_samples.len(),
1434 output_samples.len() as f64 / SAMPLE_RATE as f64,
1435 cut,
1436 );
1437
1438 Ok(output_samples)
1439 }
1440}
1441
1442fn find_safetensor_files(dir: &std::path::Path, prefix: &str) -> Result<Vec<std::path::PathBuf>> {
1446 let single = dir.join(format!("{prefix}.safetensors"));
1448 if single.exists() {
1449 return Ok(vec![single]);
1450 }
1451
1452 let mut files: Vec<std::path::PathBuf> = Vec::new();
1454 if let Ok(entries) = std::fs::read_dir(dir) {
1455 for entry in entries.flatten() {
1456 let path = entry.path();
1457 if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
1458 if name.starts_with(prefix)
1459 && name.ends_with(".safetensors")
1460 && name != format!("{prefix}.safetensors")
1461 {
1462 files.push(path);
1463 }
1464 }
1465 }
1466 }
1467 files.sort();
1468
1469 if files.is_empty() {
1470 Err(FerrumError::model(format!(
1471 "no safetensors files with prefix '{prefix}' in {}",
1472 dir.display()
1473 )))
1474 } else {
1475 Ok(files)
1476 }
1477}
1478
1479fn load_bpe_tokenizer(dir: &std::path::Path) -> Result<tokenizers::Tokenizer> {
1481 let tokenizer_json = dir.join("tokenizer.json");
1483 if tokenizer_json.exists() {
1484 return tokenizers::Tokenizer::from_file(&tokenizer_json)
1485 .map_err(|e| FerrumError::model(format!("load tokenizer.json: {e}")));
1486 }
1487
1488 let vocab_path = dir.join("vocab.json");
1490 let merges_path = dir.join("merges.txt");
1491
1492 if !vocab_path.exists() || !merges_path.exists() {
1493 return Err(FerrumError::model(
1494 "tokenizer.json not found, and vocab.json + merges.txt not found either",
1495 ));
1496 }
1497
1498 let vocab_data = std::fs::read_to_string(&vocab_path)
1499 .map_err(|e| FerrumError::model(format!("read vocab.json: {e}")))?;
1500 let vocab: HashMap<String, u32> = serde_json::from_str(&vocab_data)
1501 .map_err(|e| FerrumError::model(format!("parse vocab.json: {e}")))?;
1502
1503 let merges_data = std::fs::read_to_string(&merges_path)
1504 .map_err(|e| FerrumError::model(format!("read merges.txt: {e}")))?;
1505 let merges: Vec<(String, String)> = merges_data
1506 .lines()
1507 .skip(1) .filter(|line| !line.is_empty())
1509 .filter_map(|line| {
1510 let parts: Vec<&str> = line.splitn(2, ' ').collect();
1511 if parts.len() == 2 {
1512 Some((parts[0].to_string(), parts[1].to_string()))
1513 } else {
1514 None
1515 }
1516 })
1517 .collect();
1518
1519 let bpe = tokenizers::models::bpe::BPE::from_file(
1520 vocab_path.to_str().unwrap(),
1521 merges_path.to_str().unwrap(),
1522 )
1523 .build()
1524 .map_err(|e| FerrumError::model(format!("build BPE: {e}")))?;
1525
1526 let tokenizer = tokenizers::Tokenizer::new(bpe);
1527 Ok(tokenizer)
1528}
1529
1530fn logits_to_vec(logits: &Tensor) -> Result<Vec<f32>> {
1532 let logits = if logits.dims().len() == 3 {
1533 logits
1534 .squeeze(0)
1535 .map_err(|e| FerrumError::model(format!("squeeze: {e}")))?
1536 .squeeze(0)
1537 .map_err(|e| FerrumError::model(format!("squeeze: {e}")))?
1538 } else if logits.dims().len() == 2 {
1539 logits
1540 .squeeze(0)
1541 .map_err(|e| FerrumError::model(format!("squeeze: {e}")))?
1542 } else {
1543 logits.clone()
1544 };
1545
1546 logits
1547 .to_vec1()
1548 .map_err(|e| FerrumError::model(format!("logits to_vec1: {e}")))
1549}
1550
1551pub fn sample_token(
1559 logits: &[f32],
1560 temperature: f32,
1561 top_k: usize,
1562 _repetition_penalty: f32,
1563) -> u32 {
1564 if temperature < 0.01 {
1565 return argmax(logits);
1566 }
1567
1568 let vocab = logits.len();
1569
1570 let scaled: Vec<f32> = logits.iter().map(|&x| x / temperature).collect();
1572
1573 let mut filtered = scaled.clone();
1575 if top_k > 0 && top_k < vocab {
1576 let mut sorted = scaled.clone();
1577 sorted.sort_unstable_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
1578 let threshold = sorted[top_k - 1];
1579 for v in &mut filtered {
1580 if *v < threshold {
1581 *v = f32::NEG_INFINITY;
1582 }
1583 }
1584 }
1585
1586 const TOP_P: f32 = 0.9;
1588 {
1589 let mut indices: Vec<usize> = (0..vocab).collect();
1590 indices.sort_unstable_by(|&a, &b| {
1591 filtered[b]
1592 .partial_cmp(&filtered[a])
1593 .unwrap_or(std::cmp::Ordering::Equal)
1594 });
1595
1596 let max_val = filtered[indices[0]];
1598 let exp_sorted: Vec<f32> = indices
1599 .iter()
1600 .map(|&i| (filtered[i] - max_val).exp())
1601 .collect();
1602 let sum: f32 = exp_sorted.iter().sum();
1603 let probs_sorted: Vec<f32> = exp_sorted.iter().map(|e| e / sum).collect();
1604
1605 let mut cumsum = 0.0f32;
1607 let mut cutoff_idx = vocab;
1608 for (i, &p) in probs_sorted.iter().enumerate() {
1609 cumsum += p;
1610 if cumsum > TOP_P {
1611 cutoff_idx = i + 1;
1612 break;
1613 }
1614 }
1615
1616 for &idx in &indices[cutoff_idx..] {
1618 filtered[idx] = f32::NEG_INFINITY;
1619 }
1620 }
1621
1622 let max_val = filtered.iter().copied().fold(f32::NEG_INFINITY, f32::max);
1624 let exps: Vec<f32> = filtered.iter().map(|&v| (v - max_val).exp()).collect();
1625 let sum: f32 = exps.iter().sum();
1626 let probs: Vec<f32> = exps.iter().map(|e| e / sum).collect();
1627
1628 let r = rand_f32();
1630 let mut cumulative = 0.0f32;
1631 for (i, &p) in probs.iter().enumerate() {
1632 cumulative += p;
1633 if cumulative >= r {
1634 return i as u32;
1635 }
1636 }
1637 argmax(&probs)
1639}
1640
/// Index of the largest value in `v`, or `0` when `v` is empty.
///
/// When several elements compare as equally large, the last of them wins. Comparisons
/// that have no ordering (NaN) are treated as equal, so the later element replaces the
/// earlier one.
fn argmax(v: &[f32]) -> u32 {
    let mut winner: Option<(usize, f32)> = None;
    for (idx, &val) in v.iter().enumerate() {
        // The incumbent survives only when it is strictly greater than the challenger.
        let incumbent_wins = matches!(winner, Some((_, best)) if best > val);
        if !incumbent_wins {
            winner = Some((idx, val));
        }
    }
    winner.map_or(0, |(idx, _)| idx as u32)
}
1648
/// Return a pseudo-random `f32` in the open interval `(0, 1)`.
///
/// The generator is seeded from the sub-second nanoseconds of the system clock combined
/// with a process-wide call counter, then scrambled with the SplitMix64 finalizer so the
/// whole 64-bit range is covered. It is intended for token sampling only and is not
/// suitable for cryptographic use.
fn rand_f32() -> f32 {
    use std::sync::atomic::{AtomicU64, Ordering};
    static COUNTER: AtomicU64 = AtomicU64::new(0);

    let seed = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .subsec_nanos() as u64;
    let count = COUNTER.fetch_add(1, Ordering::Relaxed);

    // SplitMix64: advance by the golden-ratio increment per call, then mix. A plain
    // `(seed + count) * a + c` stays far below `u64::MAX` for these small inputs, which
    // would confine the result to a small slice of the unit interval.
    let mut z = seed.wrapping_add(count.wrapping_add(1).wrapping_mul(0x9E37_79B9_7F4A_7C15));
    z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    z ^= z >> 31;

    // Keep the top 23 bits and centre each bucket: `k + 0.5` with `k < 2^23` is exactly
    // representable in f32, so the result can be neither 0.0 nor 1.0.
    const BITS: u32 = 23;
    let bucket = (z >> (64 - BITS)) as f32;
    (bucket + 0.5) / (1u32 << BITS) as f32
}
1666
1667#[derive(Clone, Debug)]
1670#[allow(dead_code)]
1671struct DummyTtsCache;
1672
1673impl ferrum_interfaces::KvCacheHandle for DummyTtsCache {
1674 fn block_table(&self) -> &ferrum_interfaces::BlockTable {
1675 static EMPTY: std::sync::OnceLock<ferrum_interfaces::BlockTable> =
1676 std::sync::OnceLock::new();
1677 EMPTY.get_or_init(|| ferrum_interfaces::BlockTable::new(16))
1678 }
1679 fn block_table_mut(&mut self) -> &mut ferrum_interfaces::BlockTable {
1680 unimplemented!()
1681 }
1682 fn as_any(&self) -> &dyn std::any::Any {
1683 self
1684 }
1685 fn device(&self) -> Device {
1686 Device::CPU
1687 }
1688 fn num_layers(&self) -> usize {
1689 0
1690 }
1691 fn num_heads(&self) -> usize {
1692 0
1693 }
1694 fn head_dim(&self) -> usize {
1695 0
1696 }
1697 fn key_cache(&self, _: usize) -> Result<Option<TensorRef>> {
1698 Ok(None)
1699 }
1700 fn value_cache(&self, _: usize) -> Result<Option<TensorRef>> {
1701 Ok(None)
1702 }
1703 fn clone_handle(&self) -> Result<Arc<dyn ferrum_interfaces::KvCacheHandle>> {
1704 Ok(Arc::new(self.clone()))
1705 }
1706 fn stats(&self) -> ferrum_interfaces::CacheHandleStats {
1707 ferrum_interfaces::CacheHandleStats {
1708 memory_bytes: 0,
1709 blocks_allocated: 0,
1710 tokens_stored: 0,
1711 utilization: 0.0,
1712 last_access: std::time::Instant::now(),
1713 }
1714 }
1715 fn is_valid(&self) -> bool {
1716 true
1717 }
1718 fn cache_id(&self) -> String {
1719 "tts_dummy".to_string()
1720 }
1721}
1722
1723#[async_trait]
1724impl ModelExecutor for TtsModelExecutor {
1725 fn info(&self) -> &ModelInfo {
1726 &self.info
1727 }
1728
1729 async fn prefill(&self, _input: &PrefillInput) -> Result<PrefillOutput> {
1730 Err(FerrumError::model(
1731 "TTS uses synthesize(), not prefill/decode",
1732 ))
1733 }
1734
1735 async fn decode(&self, _input: &DecodeInput) -> Result<DecodeOutput> {
1736 Err(FerrumError::model(
1737 "TTS uses synthesize(), not prefill/decode",
1738 ))
1739 }
1740
1741 fn capabilities(&self) -> ExecutorCapabilities {
1742 ExecutorCapabilities {
1743 max_batch_size: 1,
1744 max_sequence_length: self.info.max_sequence_length,
1745 attention_mechanisms: vec![AttentionType::GroupedQuery],
1746 supports_dynamic_batching: false,
1747 supports_continuous_batching: false,
1748 supports_speculative_decoding: false,
1749 supports_tensor_parallelism: false,
1750 supports_pipeline_parallelism: false,
1751 supported_dtypes: vec![DataType::FP32, DataType::BF16],
1752 supported_devices: vec![self.info.device.clone()],
1753 memory_requirements: MemoryRequirements {
1754 parameter_memory: 0,
1755 activation_memory_per_token: 0,
1756 kv_cache_memory_per_token: 0,
1757 overhead_memory: 0,
1758 },
1759 }
1760 }
1761
1762 fn release_cache(&self, _: &str) {}
1763
1764 fn status(&self) -> ferrum_interfaces::model_executor::ExecutorStatus {
1765 common::default_executor_status()
1766 }
1767}