Skip to main content

ferrum_models/executor/
tts_executor.rs

//! Qwen3-TTS Executor — text-to-speech pipeline wiring Talker LM + Vocoder.
//!
//! Implements: text tokenization, autoregressive codec token generation,
//! SubTalker code prediction, vocoder waveform synthesis.
5
6#![allow(dead_code, unused_imports, unused_variables, unused_mut, unused_parens)]
7
8use std::collections::HashMap;
9use std::sync::{Arc, OnceLock};
10
11use async_trait::async_trait;
12use candle_core::{DType, Device as CandleDevice, Tensor};
13use candle_nn::VarBuilder;
14use ferrum_interfaces::{
15    model_executor::{
16        AttentionType, DecodeInput, DecodeOutput, ExecutorCapabilities, MemoryRequirements,
17        PrefillInput, PrefillOutput,
18    },
19    ModelExecutor, TensorRef,
20};
21use ferrum_types::{DataType, Device, FerrumError, ModelInfo, ModelType, Result};
22use tracing::info;
23
24use super::common;
25use crate::multimodal::qwen3_tts::{Qwen3TTSTalker, SubTalker, TalkerConfig};
26use crate::multimodal::qwen3_tts_backbone::TalkerBackboneBackend;
27use crate::multimodal::qwen3_tts_vocoder::{Qwen3TTSVocoder, VocoderConfig};
28use crate::multimodal::speaker_encoder::{mel_spectrogram_speaker_encoder, SpeakerEncoder};
29use crate::multimodal::speech_tokenizer_encoder::SpeechTokenizerEncoder;
30use ferrum_quantization::NativeSafetensorsLoader;
31
/// Route the `Qwen3TTSTalker` and `SubTalker` transformer stacks through
/// `TalkerBackboneBackend<CudaBackend>`, so they run via ferrum-kernels
/// cuBLAS + CUDA kernels instead of the broken fused-on-Linux CPU fallback.
///
/// The weights under `model_dir` are opened once with a
/// `NativeSafetensorsLoader<CudaBackend>`; one backbone is built for the
/// Talker and one (via `new_code_predictor`) for the SubTalker, and each is
/// handed over with `set_backend_override`.
///
/// # Errors
/// Propagates any failure from opening the loader or from building either
/// backbone. The Talker override is installed before the SubTalker backbone
/// is built, so a failure in the second step leaves only the Talker overridden.
#[cfg(feature = "cuda")]
fn install_cuda_backend_overrides(
    cfg: &TalkerConfig,
    model_dir: &std::path::Path,
    talker: &mut Qwen3TTSTalker,
    sub_talker: &mut SubTalker,
) -> Result<()> {
    use ferrum_kernels::backend::cuda::CudaBackend;

    type CudaBackbone = TalkerBackboneBackend<CudaBackend>;

    let weights: NativeSafetensorsLoader<CudaBackend> = NativeSafetensorsLoader::open(model_dir)?;

    let talker_backbone = CudaBackbone::new(cfg, &weights)?;
    talker.set_backend_override(Box::new(talker_backbone));

    let predictor_backbone = CudaBackbone::new_code_predictor(cfg, &weights)?;
    sub_talker.set_backend_override(Box::new(predictor_backbone));

    Ok(())
}
50
51// ── Constants ────────────────────────────────────────────────────────────
52
/// Output sample rate in Hz; used when reporting the duration of a waveform.
const SAMPLE_RATE: usize = 24000;
/// Upper bound on autoregressive decode steps (one codec frame per step).
const MAX_CODEC_TOKENS: usize = 2000;

/// Sampling parameters for codec token generation.
/// FERRUM_TTS_TEMP env var overrides (0.0 = greedy, 0.9 = default sampling)
const TEMPERATURE: f32 = 0.9;
/// Top-k value passed to the Talker sampler and to `SubTalker::predict`.
const TOP_K: usize = 50;
/// Penalty applied to the logits of codec tokens that were already generated.
const REPETITION_PENALTY: f32 = 1.05;
61
62#[derive(Debug, Clone, PartialEq)]
63struct TtsRuntimeEnv {
64    tts_temperature: f32,
65    st_temperature: Option<f32>,
66    ref_pcm: Option<String>,
67    ref_codes: Option<String>,
68    min_frames: Option<usize>,
69}
70
71impl TtsRuntimeEnv {
72    fn from_env() -> Self {
73        Self::from_env_vars(std::env::vars())
74    }
75
76    fn from_env_vars<I, K, V>(vars: I) -> Self
77    where
78        I: IntoIterator<Item = (K, V)>,
79        K: AsRef<str>,
80        V: Into<String>,
81    {
82        let mut tts_temperature = None;
83        let mut st_temperature = None;
84        let mut ref_pcm = None;
85        let mut ref_codes = None;
86        let mut min_frames = None;
87
88        for (key, value) in vars {
89            let value = value.into();
90            match key.as_ref() {
91                "FERRUM_TTS_TEMP" => tts_temperature = value.parse::<f32>().ok(),
92                "FERRUM_ST_TEMP" => st_temperature = value.parse::<f32>().ok(),
93                "FERRUM_REF_PCM" => ref_pcm = Some(value),
94                "FERRUM_REF_CODES" => ref_codes = Some(value),
95                "FERRUM_TTS_MIN_FRAMES" => min_frames = value.parse::<usize>().ok(),
96                _ => {}
97            }
98        }
99
100        Self {
101            tts_temperature: tts_temperature.unwrap_or(TEMPERATURE),
102            st_temperature,
103            ref_pcm,
104            ref_codes,
105            min_frames,
106        }
107    }
108
109    fn st_temperature(&self) -> f32 {
110        self.st_temperature.unwrap_or(self.tts_temperature)
111    }
112}
113
114fn tts_runtime_env() -> &'static TtsRuntimeEnv {
115    static CONFIG: OnceLock<TtsRuntimeEnv> = OnceLock::new();
116    CONFIG.get_or_init(TtsRuntimeEnv::from_env)
117}
118
119fn tts_temperature() -> f32 {
120    tts_runtime_env().tts_temperature
121}
122
123fn st_temperature() -> f32 {
124    tts_runtime_env().st_temperature()
125}
126
/// Qwen3-TTS executor: text-to-speech synthesis.
pub struct TtsModelExecutor {
    /// Main Talker LM; produces the first codec token of every frame.
    talker: Qwen3TTSTalker,
    /// Code predictor; produces the remaining codec tokens of every frame.
    sub_talker: SubTalker,
    /// Turns a codec-token tensor into a waveform.
    vocoder: Qwen3TTSVocoder,
    /// BPE tokenizer for the input text.
    text_tokenizer: tokenizers::Tokenizer,
    /// Talker configuration parsed from `config.json`.
    config: TalkerConfig,
    /// Static model description built at load time.
    info: ModelInfo,
    /// Speaker encoder for voice cloning; `None` when it could not be loaded.
    speaker_encoder: Option<SpeakerEncoder>,
    /// CPU-resident speech-tokenizer encoder; `None` when absent or when it
    /// could not be loaded.
    speech_tokenizer_encoder: Option<SpeechTokenizerEncoder>,
}
138
139impl TtsModelExecutor {
140    /// Load from model directory containing:
141    /// - config.json (TalkerConfig)
142    /// - model.safetensors (Talker weights)
143    /// - speech_tokenizer/model.safetensors (Vocoder weights)
144    /// - tokenizer_config.json + vocab.json + merges.txt (text tokenizer)
145    pub fn from_path(model_path: &str, device: CandleDevice, dtype: DType) -> Result<Self> {
146        let dir = std::path::Path::new(model_path);
147
148        // Parse TalkerConfig from config.json
149        let config_json: serde_json::Value = {
150            let config_path = dir.join("config.json");
151            let data = std::fs::read_to_string(&config_path)
152                .map_err(|e| FerrumError::model(format!("read config.json: {e}")))?;
153            serde_json::from_str(&data)
154                .map_err(|e| FerrumError::model(format!("parse config.json: {e}")))?
155        };
156        let config = TalkerConfig::from_json(&config_json)?;
157
158        // Load text tokenizer from vocab.json + merges.txt
159        let text_tokenizer = load_bpe_tokenizer(dir)?;
160
161        // Load Talker weights from model.safetensors (or sharded)
162        let talker_weights = find_safetensor_files(dir, "model")?;
163        let talker_vb = unsafe {
164            VarBuilder::from_mmaped_safetensors(&talker_weights, dtype, &device)
165                .map_err(|e| FerrumError::model(format!("load talker weights: {e}")))?
166        };
167        let mut talker = Qwen3TTSTalker::load(&config, talker_vb.clone(), device.clone())?;
168
169        // Load SubTalker (code predictor) from same weights file
170        let mut sub_talker = SubTalker::load(&config, talker_vb.clone(), device.clone())?;
171
172        // Load Speaker Encoder (for voice cloning, base models only)
173        let spk_enc_dim = config_json
174            .get("speaker_encoder_config")
175            .and_then(|c| c.get("enc_dim"))
176            .and_then(|v| v.as_u64())
177            .unwrap_or(1024) as usize;
178        let speaker_encoder =
179            SpeakerEncoder::load_with_dim(talker_vb.pp("speaker_encoder"), spk_enc_dim)
180                .map_err(|e| {
181                    tracing::warn!("Speaker encoder not available: {e}");
182                    e
183                })
184                .ok();
185
186        // Load Vocoder weights from speech_tokenizer/model.safetensors
187        let vocoder_dir = dir.join("speech_tokenizer");
188        let vocoder_weights = find_safetensor_files(&vocoder_dir, "model")?;
189        let vocoder_vb = unsafe {
190            VarBuilder::from_mmaped_safetensors(&vocoder_weights, dtype, &device)
191                .map_err(|e| FerrumError::model(format!("load vocoder weights: {e}")))?
192        };
193        let vocoder_config = VocoderConfig::default();
194        let vocoder = Qwen3TTSVocoder::load(&vocoder_config, vocoder_vb.clone())?;
195
196        // Load Speech Tokenizer Encoder on CPU — Metal float32 accumulation order
197        // causes transformer output divergence that amplifies through RVQ codebook lookup.
198        // CPU is exact and encoder only runs once per reference audio.
199        let speech_tokenizer_encoder = if vocoder_dir.join("config.json").exists() {
200            let cpu_vb = unsafe {
201                VarBuilder::from_mmaped_safetensors(&vocoder_weights, dtype, &CandleDevice::Cpu)
202                    .map_err(|e| FerrumError::model(format!("load encoder cpu: {e}")))?
203            };
204            SpeechTokenizerEncoder::load(cpu_vb.pp("encoder"), CandleDevice::Cpu)
205                .map_err(|e| {
206                    tracing::warn!("Speech tokenizer encoder not available: {e}");
207                    e
208                })
209                .ok()
210        } else {
211            None
212        };
213
214        // Install a Backend<B>-backed transformer for the Talker/SubTalker
215        // when running on CUDA. The legacy `ferrum_kernels::attention::FusedTransformer`
216        // CUDA module is a stub and its Linux CPU fallback uses naive fp64
217        // matmul — the CUDA-only voice-clone regression traces back to that
218        // path accumulating precision drift through the 20-layer decoder.
219        // Routing through LlamaFamilyModel<CudaBackend> gives us cuBLAS +
220        // ferrum-kernels for the transformer stack while keeping candle
221        // embeddings / projection / codec_head unchanged.
222        #[cfg(feature = "cuda")]
223        if matches!(&device, CandleDevice::Cuda(_)) {
224            match install_cuda_backend_overrides(&config, dir, &mut talker, &mut sub_talker) {
225                Ok(()) => {
226                    tracing::info!(
227                        "TtsModelExecutor: Backend<CudaBackend> installed for Talker + SubTalker"
228                    );
229                }
230                Err(e) => {
231                    tracing::warn!(
232                        "TtsModelExecutor: Backend<CudaBackend> install failed ({e}); \
233                         falling back to candle/fused path (CUDA voice-clone may produce garbage)"
234                    );
235                }
236            }
237        }
238
239        let info = ModelInfo {
240            model_id: ferrum_types::ModelId(model_path.to_string()),
241            model_type: ModelType::Custom("qwen3-tts".to_string()),
242            hidden_size: config.hidden_size,
243            vocab_size: config.vocab_size,
244            num_layers: config.num_hidden_layers,
245            num_heads: config.num_attention_heads,
246            num_kv_heads: config.num_key_value_heads,
247            num_parameters: 0,
248            max_sequence_length: config.max_position_embeddings,
249            device: match &device {
250                CandleDevice::Cpu => Device::CPU,
251                CandleDevice::Cuda(_) => Device::CUDA(0),
252                #[cfg(any(target_os = "macos", target_os = "ios"))]
253                CandleDevice::Metal(_) => Device::Metal,
254                #[cfg(not(any(target_os = "macos", target_os = "ios")))]
255                CandleDevice::Metal(_) => Device::CPU,
256            },
257            dtype: match dtype {
258                DType::F32 => DataType::FP32,
259                DType::F16 => DataType::FP16,
260                DType::BF16 => DataType::BF16,
261                _ => DataType::FP32,
262            },
263            version: None,
264            license: None,
265            metadata: HashMap::new(),
266        };
267
268        info!(
269            "TtsModelExecutor: {} (hidden={}, layers={}, codec_groups={})",
270            model_path, config.hidden_size, config.num_hidden_layers, config.num_code_groups,
271        );
272
273        Ok(Self {
274            talker,
275            sub_talker,
276            vocoder,
277            text_tokenizer,
278            config,
279            info,
280            speaker_encoder,
281            speech_tokenizer_encoder,
282        })
283    }
284
285    /// Synthesize speech from text.
286    ///
287    /// Returns PCM samples at 24kHz as Vec<f32>.
288    ///
289    /// Prompt structure (matches Python/qwen3-tts-rs):
290    ///   Prefill: [role_prefix(3)] + [tts_text_prefix(6) + codec_prefix(6)] + [first_text + codec_bos]
291    ///   Trailing: text_projection(remaining_text + tts_eos) — added per decode step
292    pub fn synthesize(&mut self, text: &str, language: &str) -> Result<Vec<f32>> {
293        self.talker.reset();
294
295        let device = self.talker.device().clone();
296
297        // 1. Tokenize text (raw content only, no chat template)
298        let encoding = self
299            .text_tokenizer
300            .encode(text, false)
301            .map_err(|e| FerrumError::model(format!("tokenize: {e}")))?;
302        let content_ids: Vec<u32> = encoding.get_ids().to_vec();
303
304        if content_ids.is_empty() {
305            return Err(FerrumError::model("empty text after tokenization"));
306        }
307
308        info!("TTS: content tokens = {}", content_ids.len());
309
310        let codec_eos = self.config.codec_eos_token_id;
311        let tts_pad = self.config.tts_pad_token_id;
312        let tts_bos = self.config.tts_bos_token_id;
313        let tts_eos = self.config.tts_eos_token_id;
314
315        // Helper: embed text token IDs through text_embedding + text_projection
316        let embed_text_ids = |ids: &[u32]| -> Result<Tensor> {
317            let t = Tensor::new(ids, &device)
318                .and_then(|t| t.unsqueeze(0))
319                .map_err(|e| FerrumError::model(format!("text ids: {e}")))?;
320            self.talker.embed_text(&t)
321        };
322        let embed_codec_ids = |ids: &[u32]| -> Result<Tensor> {
323            let t = Tensor::new(ids, &device)
324                .and_then(|t| t.unsqueeze(0))
325                .map_err(|e| FerrumError::model(format!("codec ids: {e}")))?;
326            self.talker.embed_codec(&t)
327        };
328
329        // 2. Build prefill (matching qwen3-tts-rs prefill_custom_voice)
330
331        // Role prefix: text_projection([<|im_start|>, assistant, \n])
332        // These are fixed special token IDs from the tokenizer
333        let im_start_id = 151644u32; // <|im_start|>
334        let assistant_id = 77091u32; // "assistant"
335        let newline_id = 198u32; // "\n"
336        let role_prefix_ids = [im_start_id, assistant_id, newline_id];
337
338        info!(
339            "TTS: role_prefix={} content={} tokens",
340            role_prefix_ids.len(),
341            content_ids.len()
342        );
343
344        // Role prefix embedding (text_projection)
345        let role_embed = embed_text_ids(&role_prefix_ids)?;
346
347        // Codec prefix: [think, think_bos, lang, think_eos, speaker, pad]
348        let resolved_lang = if language == "auto" {
349            "chinese"
350        } else {
351            language
352        };
353        let language_id = self
354            .config
355            .codec_language_id
356            .get(&resolved_lang.to_lowercase());
357        let codec_prefix_ids = if let Some(&lang_id) = language_id {
358            vec![
359                self.config.codec_think_id,
360                self.config.codec_think_bos_id,
361                lang_id,
362                self.config.codec_think_eos_id,
363            ]
364        } else {
365            vec![
366                self.config.codec_nothink_id,
367                self.config.codec_think_bos_id,
368                self.config.codec_think_eos_id,
369            ]
370        };
371        // Codec sequence: [think, think_bos, lang, think_eos, SPEAKER, pad, bos]
372        // Speaker: use Vivian (3065) for Chinese, Ryan (3061) for English
373        let speaker_token = if resolved_lang == "chinese" {
374            3065u32
375        } else {
376            3061u32
377        };
378        let codec_full = {
379            let mut v = codec_prefix_ids.clone();
380            v.push(speaker_token);
381            v.push(self.config.codec_pad_id);
382            v.push(self.config.codec_bos_id);
383            v
384        };
385        let codec_embed = embed_codec_ids(&codec_full)?;
386
387        // tts_text_prefix: [tts_pad * (codec_len-1), tts_bos]
388        let n_codec = codec_full.len();
389        let mut tts_prefix_ids = vec![tts_pad; n_codec - 1];
390        tts_prefix_ids.push(tts_bos);
391        let tts_prefix_embed = embed_text_ids(&tts_prefix_ids)?;
392
393        // Sum: tts_prefix + codec_prefix[:-1]
394        let codec_first = codec_embed
395            .narrow(1, 0, n_codec - 1)
396            .map_err(|e| FerrumError::model(format!("codec narrow: {e}")))?;
397        // Actually we need codec_first to have same length as tts_prefix
398        // tts_prefix has n_codec elements, codec_first has n_codec-1
399        // Let me re-read the reference... codec_embed has 6 tokens [think..pad,bos], tts_prefix has 6 [pad*5,bos]
400        // They sum first 6 codec with first 6 tts_prefix, then codec_bos is separate
401        // codec_full has [think, think_bos, lang, think_eos, speaker, pad, bos] = 7 tokens
402        // tts_prefix: [pad*5, bos] overlaid on codec[0:6] (first 6, excluding last bos)
403        // Then: first_text + codec_bos (last token) summed
404        let n_prefix = n_codec - 1; // 6: everything except codec_bos
405        let codec_prefix_part = codec_embed
406            .narrow(1, 0, n_prefix)
407            .map_err(|e| FerrumError::model(format!("codec narrow: {e}")))?;
408
409        // tts text prefix: [pad * (n_prefix-1), bos]
410        let mut tts_text_prefix_ids = vec![tts_pad; n_prefix - 1];
411        tts_text_prefix_ids.push(tts_bos);
412        let tts_text_embed = embed_text_ids(&tts_text_prefix_ids)?;
413
414        let codec_hidden = (&tts_text_embed + &codec_prefix_part)
415            .map_err(|e| FerrumError::model(format!("prefix sum: {e}")))?;
416
417        // codec_bos is the last element of codec_full
418        let codec_bos_embed = codec_embed
419            .narrow(1, n_prefix, 1)
420            .map_err(|e| FerrumError::model(format!("codec bos: {e}")))?;
421
422        // First text token + codec_bos (summed)
423        let first_text_combined = if !content_ids.is_empty() {
424            let first_text_embed = embed_text_ids(&content_ids[..1])?;
425            (&first_text_embed + &codec_bos_embed)
426                .map_err(|e| FerrumError::model(format!("first text+bos: {e}")))?
427        } else {
428            codec_bos_embed.clone()
429        };
430
431        // Full prefill: [role_prefix, codec_hidden, first_text+codec_bos]
432        let prefill_embeds = Tensor::cat(&[&role_embed, &codec_hidden, &first_text_combined], 1)
433            .map_err(|e| FerrumError::model(format!("prefill cat: {e}")))?;
434
435        let plen = prefill_embeds.dim(1).unwrap_or(0);
436        info!("TTS: prefill_len = {}", plen);
437        // Dump prefill input for comparison with reference
438        if let Ok(v) = prefill_embeds
439            .narrow(0, 0, 1)
440            .and_then(|t| t.narrow(1, 0, 1))
441            .and_then(|t| t.narrow(2, 0, 5))
442            .and_then(|t| t.flatten_all())
443            .and_then(|t| t.to_vec1::<f32>())
444        {
445            info!("  prefill_input pos0[:5] = {:?}", v);
446        }
447        if plen > 0 {
448            if let Ok(v) = prefill_embeds
449                .narrow(0, 0, 1)
450                .and_then(|t| t.narrow(1, plen - 1, 1))
451                .and_then(|t| t.narrow(2, 0, 5))
452                .and_then(|t| t.flatten_all())
453                .and_then(|t| t.to_vec1::<f32>())
454            {
455                info!("  prefill_input pos-1[:5] = {:?}", v);
456            }
457        }
458
459        // 3. Build trailing text: text_projection(remaining_content + tts_eos)
460        let mut trailing_ids: Vec<u32> = if content_ids.len() > 1 {
461            content_ids[1..].to_vec()
462        } else {
463            Vec::new()
464        };
465        trailing_ids.push(tts_eos);
466        let trailing_text_embeds = embed_text_ids(&trailing_ids)?;
467        let trailing_text_len = trailing_text_embeds
468            .dim(1)
469            .map_err(|e| FerrumError::model(format!("trailing dim: {e}")))?;
470        let tts_pad_embed = embed_text_ids(&[tts_pad])?;
471
472        info!("TTS: trailing_text_len = {}", trailing_text_len);
473
474        // 4. Prefill: forward through transformer
475        let mut hidden = self.talker.forward_step(&prefill_embeds)?;
476
477        // Get logits from last position
478        let current_logits = self.talker.logits(
479            &hidden
480                .narrow(1, hidden.dim(1).unwrap() - 1, 1)
481                .map_err(|e| FerrumError::model(format!("narrow: {e}")))?,
482        )?;
483
484        // 4. Autoregressive decode loop: generate codec token 0 per step
485        let mut all_codec_tokens: Vec<Vec<u32>> = Vec::new();
486        let mut current_logits = current_logits;
487
488        // Token suppression: mask [vocab_size-1024, vocab_size) except EOS
489        let suppress_start = self.config.vocab_size.saturating_sub(1024);
490        let suppress_end = self.config.vocab_size;
491        let mut generated_tokens: Vec<u32> = Vec::new();
492
493        for step in 0..MAX_CODEC_TOKENS {
494            // Sample next codec token with suppression + repetition penalty
495            let mut logits_vec = logits_to_vec(&current_logits)?;
496            // Suppress special tokens
497            for i in suppress_start..suppress_end.min(logits_vec.len()) {
498                if i as u32 != codec_eos {
499                    logits_vec[i] = f32::NEG_INFINITY;
500                }
501            }
502            // Repetition penalty (matching Python's repetition_penalty=1.05)
503            for &prev_tok in &generated_tokens {
504                let idx = prev_tok as usize;
505                if idx < logits_vec.len() {
506                    if logits_vec[idx] > 0.0 {
507                        logits_vec[idx] /= REPETITION_PENALTY;
508                    } else {
509                        logits_vec[idx] *= REPETITION_PENALTY;
510                    }
511                }
512            }
513            let next_token =
514                sample_token(&logits_vec, tts_temperature(), TOP_K, REPETITION_PENALTY);
515
516            if step < 10 {
517                // Find argmax (greedy token)
518                let argmax_tok = logits_vec
519                    .iter()
520                    .enumerate()
521                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
522                    .map(|(i, v)| (i, *v))
523                    .unwrap_or((0, 0.0));
524                info!(
525                    "TOKEN step={} sampled={} argmax=({}, {:.2})",
526                    step, next_token, argmax_tok.0, argmax_tok.1
527                );
528            }
529
530            generated_tokens.push(next_token);
531
532            // Check for EOS
533            if next_token == codec_eos {
534                info!("TTS: codec EOS at step {}", step);
535                break;
536            }
537
538            // Get last hidden state from talker for SubTalker
539            let last_hidden = hidden
540                .narrow(1, hidden.dim(1).unwrap() - 1, 1)
541                .map_err(|e| FerrumError::model(format!("last_hidden: {e}")))?;
542
543            // Embed first codec token
544            let token_tensor = Tensor::new(&[next_token], &device)
545                .map_err(|e| FerrumError::model(format!("token tensor: {e}")))?
546                .unsqueeze(0)
547                .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?;
548            let first_codec_embed = self.talker.embed_codec(&token_tensor)?;
549
550            // SubTalker: predict remaining codec tokens 1..num_code_groups-1
551            let st_t0 = std::time::Instant::now();
552            let extra_codes = self.sub_talker.predict(
553                &last_hidden,
554                &first_codec_embed,
555                st_temperature(),
556                TOP_K,
557            )?;
558            if step == 0 {
559                info!(
560                    "  SubTalker: {:.1}ms",
561                    st_t0.elapsed().as_secs_f64() * 1000.0
562                );
563            }
564
565            let mut frame_codes = vec![next_token];
566            frame_codes.extend_from_slice(&extra_codes);
567            all_codec_tokens.push(frame_codes);
568
569            // Build combined embedding for next talker step:
570            // sum of all codec embeddings (token 0 from main talker + tokens 1-15 from sub-talker)
571            let mut combined_embed = first_codec_embed.clone();
572            for (i, &code) in extra_codes.iter().enumerate() {
573                let code_t = Tensor::new(&[code], &device)
574                    .and_then(|t| t.unsqueeze(0))
575                    .map_err(|e| FerrumError::model(format!("code_t: {e}")))?;
576                let sub_embed = code_t
577                    .apply(&self.sub_talker.codec_embeddings[i])
578                    .map_err(|e| FerrumError::model(format!("sub_embed: {e}")))?;
579                combined_embed = (combined_embed + sub_embed)
580                    .map_err(|e| FerrumError::model(format!("add embed: {e}")))?;
581            }
582
583            if step == 0 {
584                if let Ok(v) = combined_embed
585                    .flatten_all()
586                    .and_then(|t| t.narrow(0, 0, 5))
587                    .and_then(|t| t.to_vec1::<f32>())
588                {
589                    info!("STEP0 codec_sum[:5] = {:?} (before trailing)", v);
590                }
591            }
592            // Add trailing text embedding (guides generation toward target text)
593            // Python: inputs_embeds = codec_sum + trailing_text_hidden[:, gen_step]
594            // trailing_text = text_projection(remaining_text) + tts_eos
595            // For basic TTS, trailing covers all text tokens after the first
596            if step < trailing_text_len {
597                let trail = trailing_text_embeds
598                    .narrow(1, step, 1)
599                    .map_err(|e| FerrumError::model(format!("trailing narrow: {e}")))?;
600                combined_embed = (combined_embed + trail)
601                    .map_err(|e| FerrumError::model(format!("add trailing: {e}")))?;
602            } else {
603                combined_embed = (combined_embed + &tts_pad_embed)
604                    .map_err(|e| FerrumError::model(format!("add tts_pad: {e}")))?;
605            }
606
607            // Debug: dump step 0 components
608            if step == 0 {
609                if let Ok(v) = first_codec_embed
610                    .flatten_all()
611                    .and_then(|t| t.narrow(0, 0, 5))
612                    .and_then(|t| t.to_vec1::<f32>())
613                {
614                    info!("STEP0 semantic[:5] = {:?}", v);
615                }
616                if let Ok(v) = combined_embed
617                    .flatten_all()
618                    .and_then(|t| t.narrow(0, 0, 5))
619                    .and_then(|t| t.to_vec1::<f32>())
620                {
621                    info!("STEP0 combined[:5] = {:?}", v);
622                }
623            }
624            // Forward combined embedding through talker
625            let tk_t0 = std::time::Instant::now();
626            hidden = self.talker.forward_step(&combined_embed)?;
627            current_logits = self.talker.logits(&hidden)?;
628            if step == 0 {
629                info!(
630                    "  Talker step: {:.1}ms",
631                    tk_t0.elapsed().as_secs_f64() * 1000.0
632                );
633            }
634        }
635
636        if all_codec_tokens.is_empty() {
637            return Err(FerrumError::model("no codec tokens generated"));
638        }
639
640        info!("TTS: generated {} codec frames", all_codec_tokens.len());
641
642        // 5. Build codec tensor [1, num_code_groups, T] for vocoder
643        let num_frames = all_codec_tokens.len();
644        let num_groups = self.config.num_code_groups;
645        let mut flat_codes: Vec<u32> = vec![0; num_groups * num_frames];
646        for (t, frame) in all_codec_tokens.iter().enumerate() {
647            for (g, &code) in frame.iter().enumerate() {
648                flat_codes[g * num_frames + t] = code;
649            }
650        }
651
652        // Clamp codes to valid codebook range [0, codebook_size-1].
653        // Special tokens (codec_bos, codec_eos, etc.) are >= codebook_size and must be clamped.
654        let codebook_size = 2048u32;
655        for code in &mut flat_codes {
656            if *code >= codebook_size {
657                *code = 0; // replace special tokens with pad/silence
658            }
659        }
660
661        let codes_tensor = Tensor::new(&flat_codes[..], &device)
662            .map_err(|e| FerrumError::model(format!("codes tensor: {e}")))?
663            .reshape((1, num_groups, num_frames))
664            .map_err(|e| FerrumError::model(format!("reshape codes: {e}")))?;
665
666        // 6. Vocoder: codec tokens → waveform
667        let waveform = self.vocoder.decode(&codes_tensor)?;
668
669        // Extract samples: [1, 1, samples] → Vec<f32>
670        let samples: Vec<f32> = waveform
671            .squeeze(0)
672            .map_err(|e| FerrumError::model(format!("squeeze batch: {e}")))?
673            .squeeze(0)
674            .map_err(|e| FerrumError::model(format!("squeeze channel: {e}")))?
675            .to_vec1()
676            .map_err(|e| FerrumError::model(format!("to_vec1: {e}")))?;
677
678        info!(
679            "TTS: waveform {} samples ({:.2}s @ {}Hz)",
680            samples.len(),
681            samples.len() as f64 / SAMPLE_RATE as f64,
682            SAMPLE_RATE,
683        );
684
685        Ok(samples)
686    }
687
688    /// Decode a chunk of codec frames to audio samples.
689    /// frames: Vec<Vec<u32>> where each inner Vec has num_code_groups elements.
690    fn decode_frames(&mut self, frames: &[Vec<u32>], device: &CandleDevice) -> Result<Vec<f32>> {
691        let num_frames = frames.len();
692        if num_frames == 0 {
693            return Ok(vec![]);
694        }
695        let num_groups = self.config.num_code_groups;
696        let codebook_size = 2048u32;
697
698        let mut flat_codes: Vec<u32> = vec![0; num_groups * num_frames];
699        for (t, frame) in frames.iter().enumerate() {
700            for (g, &code) in frame.iter().take(num_groups).enumerate() {
701                flat_codes[g * num_frames + t] = if code >= codebook_size { 0 } else { code };
702            }
703        }
704
705        let codes_tensor = Tensor::new(&flat_codes[..], device)
706            .map_err(|e| FerrumError::model(format!("codes tensor: {e}")))?
707            .reshape((1, num_groups, num_frames))
708            .map_err(|e| FerrumError::model(format!("reshape codes: {e}")))?;
709
710        let waveform = self.vocoder.decode(&codes_tensor)?;
711        waveform
712            .squeeze(0)
713            .and_then(|t| t.squeeze(0))
714            .and_then(|t| t.to_vec1())
715            .map_err(|e| FerrumError::model(format!("waveform extract: {e}")))
716    }
717
718    /// Streaming TTS: calls `on_chunk` with each audio chunk as soon as it's ready.
719    ///
720    /// Each chunk is `chunk_frames` codec frames decoded to audio (~800ms at default 10 frames).
721    /// First chunk arrives after `chunk_frames` decode steps (~2-3s for 0.6B).
722    pub fn synthesize_streaming<F: FnMut(usize, &[f32])>(
723        &mut self,
724        text: &str,
725        language: &str,
726        chunk_frames: usize,
727        mut on_chunk: F,
728    ) -> Result<Vec<Vec<f32>>> {
729        // Reuse existing synthesize setup (prefill + trailing text)
730        // but yield audio in chunks instead of all at once
731        self.talker.reset();
732        let device = self.talker.device().clone();
733
734        let encoding = self
735            .text_tokenizer
736            .encode(text, false)
737            .map_err(|e| FerrumError::model(format!("tokenize: {e}")))?;
738        let content_ids: Vec<u32> = encoding.get_ids().to_vec();
739        if content_ids.is_empty() {
740            return Err(FerrumError::model("empty text after tokenization"));
741        }
742
743        let codec_eos = self.config.codec_eos_token_id;
744        let tts_pad = self.config.tts_pad_token_id;
745        let tts_bos = self.config.tts_bos_token_id;
746        let tts_eos = self.config.tts_eos_token_id;
747
748        // Build embeddings (same as synthesize)
749        let embed_text_ids = |ids: &[u32]| -> Result<Tensor> {
750            let t = Tensor::new(ids, &device)
751                .map_err(|e| FerrumError::model(format!("text tensor: {e}")))?
752                .unsqueeze(0)
753                .map_err(|e| FerrumError::model(format!("text unsqueeze: {e}")))?;
754            self.talker.embed_text(&t)
755        };
756        let embed_codec_ids = |ids: &[u32]| -> Result<Tensor> {
757            let t = Tensor::new(ids, &device)
758                .map_err(|e| FerrumError::model(format!("codec tensor: {e}")))?
759                .unsqueeze(0)
760                .map_err(|e| FerrumError::model(format!("codec unsqueeze: {e}")))?;
761            self.talker.embed_codec(&t)
762        };
763
764        // Codec/text prefix (same as synthesize)
765        let resolved_lang = if language.eq_ignore_ascii_case("auto") {
766            "chinese"
767        } else {
768            language
769        };
770        let language_id = self
771            .config
772            .codec_language_id
773            .get(&resolved_lang.to_lowercase());
774        let codec_prefix_ids = if let Some(&lang_id) = language_id {
775            vec![
776                self.config.codec_think_id,
777                self.config.codec_think_bos_id,
778                lang_id,
779                self.config.codec_think_eos_id,
780            ]
781        } else {
782            vec![
783                self.config.codec_nothink_id,
784                self.config.codec_think_bos_id,
785                self.config.codec_think_eos_id,
786            ]
787        };
788        let speaker_token = if resolved_lang == "chinese" {
789            3065u32
790        } else {
791            3061u32
792        };
793        let mut codec_ids = codec_prefix_ids;
794        codec_ids.push(speaker_token);
795        codec_ids.push(self.config.codec_pad_id);
796        codec_ids.push(self.config.codec_bos_id);
797        let codec_embed = embed_codec_ids(&codec_ids)?;
798        let n_codec = codec_embed
799            .dim(1)
800            .map_err(|e| FerrumError::model(format!("dim: {e}")))?;
801        let n_prefix = n_codec - 1;
802        let codec_prefix_part = codec_embed
803            .narrow(1, 0, n_prefix)
804            .map_err(|e| FerrumError::model(format!("narrow: {e}")))?;
805
806        let mut tts_text_prefix_ids = vec![tts_pad; n_prefix - 1];
807        tts_text_prefix_ids.push(tts_bos);
808        let tts_text_embed = embed_text_ids(&tts_text_prefix_ids)?;
809        let codec_hidden = (&tts_text_embed + &codec_prefix_part)
810            .map_err(|e| FerrumError::model(format!("sum: {e}")))?;
811        let codec_bos_embed = codec_embed
812            .narrow(1, n_prefix, 1)
813            .map_err(|e| FerrumError::model(format!("bos: {e}")))?;
814
815        let role_ids: &[u32] = &[151644, 77091, 198]; // im_start, assistant, \n
816        let role_embed = embed_text_ids(role_ids)?;
817
818        let first_text_combined = if !content_ids.is_empty() {
819            let first_text_embed = embed_text_ids(&content_ids[..1])?;
820            (&first_text_embed + &codec_bos_embed)
821                .map_err(|e| FerrumError::model(format!("first: {e}")))?
822        } else {
823            codec_bos_embed.clone()
824        };
825
826        let prefill_embeds = Tensor::cat(&[&role_embed, &codec_hidden, &first_text_combined], 1)
827            .map_err(|e| FerrumError::model(format!("prefill cat: {e}")))?;
828
829        // Trailing text
830        let trailing_text_embeds = if content_ids.len() > 1 {
831            let remaining = embed_text_ids(&content_ids[1..])?;
832            let eos = embed_text_ids(&[tts_eos])?;
833            Tensor::cat(&[&remaining, &eos], 1)
834                .map_err(|e| FerrumError::model(format!("trailing: {e}")))?
835        } else {
836            embed_text_ids(&[tts_eos])?
837        };
838        let trailing_text_len = trailing_text_embeds
839            .dim(1)
840            .map_err(|e| FerrumError::model(format!("dim: {e}")))?;
841        let tts_pad_embed = embed_text_ids(&[tts_pad])?;
842
843        // Prefill
844        let mut hidden = self.talker.forward_step(&prefill_embeds)?;
845        let mut current_logits = self.talker.logits(
846            &hidden
847                .narrow(1, hidden.dim(1).unwrap() - 1, 1)
848                .map_err(|e| FerrumError::model(format!("narrow: {e}")))?,
849        )?;
850
851        // Streaming decode loop
852        let suppress_start = self.config.vocab_size.saturating_sub(1024);
853        let suppress_end = self.config.vocab_size;
854        let mut generated_tokens: Vec<u32> = Vec::new();
855        let mut frame_buffer: Vec<Vec<u32>> = Vec::new();
856        let mut audio_chunks: Vec<Vec<f32>> = Vec::new();
857
858        for step in 0..MAX_CODEC_TOKENS {
859            let mut logits_vec = logits_to_vec(&current_logits)?;
860            for i in suppress_start..suppress_end.min(logits_vec.len()) {
861                if i as u32 != codec_eos {
862                    logits_vec[i] = f32::NEG_INFINITY;
863                }
864            }
865            for &prev_tok in &generated_tokens {
866                let idx = prev_tok as usize;
867                if idx < logits_vec.len() {
868                    if logits_vec[idx] > 0.0 {
869                        logits_vec[idx] /= REPETITION_PENALTY;
870                    } else {
871                        logits_vec[idx] *= REPETITION_PENALTY;
872                    }
873                }
874            }
875            let next_token =
876                sample_token(&logits_vec, tts_temperature(), TOP_K, REPETITION_PENALTY);
877            generated_tokens.push(next_token);
878
879            if next_token == codec_eos {
880                info!("TTS streaming: EOS at step {}", step);
881                break;
882            }
883
884            let last_hidden = hidden
885                .narrow(1, hidden.dim(1).unwrap() - 1, 1)
886                .map_err(|e| FerrumError::model(format!("narrow: {e}")))?;
887            let token_tensor = Tensor::new(&[next_token], &device)
888                .and_then(|t| t.unsqueeze(0))
889                .map_err(|e| FerrumError::model(format!("tok: {e}")))?;
890            let first_codec_embed = self.talker.embed_codec(&token_tensor)?;
891
892            let extra_codes = self.sub_talker.predict(
893                &last_hidden,
894                &first_codec_embed,
895                st_temperature(),
896                TOP_K,
897            )?;
898
899            let mut frame = vec![next_token];
900            frame.extend_from_slice(&extra_codes);
901            frame_buffer.push(frame);
902
903            // Emit chunk when buffer is full
904            if frame_buffer.len() >= chunk_frames {
905                let chunk_audio = self.decode_frames(&frame_buffer, &device)?;
906                on_chunk(audio_chunks.len(), &chunk_audio);
907                audio_chunks.push(chunk_audio);
908                frame_buffer.clear();
909            }
910
911            // Build combined embed for next step
912            let mut combined_embed = first_codec_embed.clone();
913            for (i, &code) in extra_codes.iter().enumerate() {
914                let code_t = Tensor::new(&[code], &device)
915                    .and_then(|t| t.unsqueeze(0))
916                    .map_err(|e| FerrumError::model(format!("code_t: {e}")))?;
917                let sub_embed = code_t
918                    .apply(&self.sub_talker.codec_embeddings[i])
919                    .map_err(|e| FerrumError::model(format!("sub: {e}")))?;
920                combined_embed = (combined_embed + sub_embed)
921                    .map_err(|e| FerrumError::model(format!("add: {e}")))?;
922            }
923            if step < trailing_text_len {
924                let trail = trailing_text_embeds
925                    .narrow(1, step, 1)
926                    .map_err(|e| FerrumError::model(format!("trail: {e}")))?;
927                combined_embed = (combined_embed + trail)
928                    .map_err(|e| FerrumError::model(format!("add trail: {e}")))?;
929            } else {
930                combined_embed = (combined_embed + &tts_pad_embed)
931                    .map_err(|e| FerrumError::model(format!("pad: {e}")))?;
932            }
933
934            hidden = self.talker.forward_step(&combined_embed)?;
935            current_logits = self.talker.logits(&hidden)?;
936        }
937
938        // Flush remaining frames
939        if !frame_buffer.is_empty() {
940            let chunk_audio = self.decode_frames(&frame_buffer, &device)?;
941            on_chunk(audio_chunks.len(), &chunk_audio);
942            audio_chunks.push(chunk_audio);
943        }
944
945        Ok(audio_chunks)
946    }
947
948    /// Get the output sample rate.
949    pub fn sample_rate(&self) -> usize {
950        SAMPLE_RATE
951    }
952
953    pub fn config(&self) -> &TalkerConfig {
954        &self.config
955    }
956
957    /// Synthesize speech with voice cloning from a reference audio.
958    ///
959    /// Uses ICL (in-context learning) prompting: the reference audio is
960    /// encoded to codec tokens and prepended to the generation prompt,
961    /// along with a speaker embedding extracted via ECAPA-TDNN.
962    ///
963    /// Returns PCM samples at 24kHz as Vec<f32>.
964    pub fn synthesize_voice_clone(
965        &mut self,
966        text: &str,
967        language: &str,
968        ref_audio_path: &str,
969        ref_text: &str,
970    ) -> Result<Vec<f32>> {
971        let device = self.talker.device().clone();
972
973        // Step 1: Load and process reference audio at 24kHz
974        let ref_pcm = if let Some(path) = tts_runtime_env().ref_pcm.as_deref() {
975            let data = std::fs::read(&path)
976                .map_err(|e| FerrumError::model(format!("read ref pcm: {e}")))?;
977            let pcm: Vec<f32> = data
978                .chunks(4)
979                .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
980                .collect();
981            info!("Loaded ref PCM override: {} samples", pcm.len());
982            pcm
983        } else {
984            crate::audio_processor::load_audio_at_rate(ref_audio_path, 24000)?
985        };
986        info!(
987            "TTS voice clone: loaded ref audio {} samples ({:.2}s)",
988            ref_pcm.len(),
989            ref_pcm.len() as f64 / 24000.0
990        );
991
992        let t0 = std::time::Instant::now();
993        // Step 2: Extract speaker embedding via ECAPA-TDNN
994        let speaker_encoder = self
995            .speaker_encoder
996            .as_ref()
997            .ok_or_else(|| FerrumError::model("speaker encoder not loaded"))?;
998        let mel = mel_spectrogram_speaker_encoder(&ref_pcm);
999        let n_mel_frames = mel.len() / 128;
1000        // mel_spectrogram_speaker_encoder returns [T, 128] row-major
1001        // forward() internally transposes [1, T, 128] → [1, 128, T]
1002        let mel_tensor = Tensor::from_vec(mel, (1, n_mel_frames, 128), &device)
1003            .map_err(|e| FerrumError::model(format!("mel tensor: {e}")))?;
1004        let spk_embed = speaker_encoder.forward(&mel_tensor)?;
1005        info!(
1006            "Step 2 (speaker embed): {:.1}ms",
1007            t0.elapsed().as_secs_f64() * 1000.0
1008        );
1009        // spk_embed shape: [enc_dim] -> reshape to [1, 1, hidden_size]
1010        let spk_embed = spk_embed
1011            .unsqueeze(0)
1012            .map_err(|e| FerrumError::model(format!("spk unsqueeze(0): {e}")))?
1013            .unsqueeze(0)
1014            .map_err(|e| FerrumError::model(format!("spk unsqueeze(0) 2: {e}")))?;
1015
1016        let t1 = std::time::Instant::now();
1017        // Step 3: Encode reference audio to codec tokens (ICL)
1018        let speech_enc = self
1019            .speech_tokenizer_encoder
1020            .as_ref()
1021            .ok_or_else(|| FerrumError::model("speech tokenizer encoder not loaded"))?;
1022        // Allow pre-computed codec tokens for debugging (FERRUM_REF_CODES=/path/to/codes.bin)
1023        let ref_codes = if let Some(path) = tts_runtime_env().ref_codes.as_deref() {
1024            let data = std::fs::read(&path)
1025                .map_err(|e| FerrumError::model(format!("read ref codes: {e}")))?;
1026            let u32s: Vec<u32> = data
1027                .chunks(4)
1028                .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
1029                .collect();
1030            let ncb = self.config.num_code_groups;
1031            let nframes = u32s.len() / ncb;
1032            info!(
1033                "Loaded pre-computed ref codes: {} frames from {}",
1034                nframes, path
1035            );
1036            u32s.chunks(ncb).map(|c| c.to_vec()).collect()
1037        } else {
1038            let codes = speech_enc.encode(&ref_pcm)?;
1039            info!(
1040                "Step 3 (speech tokenizer): {:.1}ms",
1041                t1.elapsed().as_secs_f64() * 1000.0
1042            );
1043            codes
1044        };
1045        let ref_frames = ref_codes.len();
1046        info!(
1047            "TTS voice clone: ref_frames={}, spk_embed loaded",
1048            ref_frames
1049        );
1050        // Debug: dump first 5 codec frames for comparison with Python
1051        for i in 0..ref_frames.min(5) {
1052            info!("  rust codec frame {}: {:?}", i, &ref_codes[i]);
1053        }
1054
1055        info!(
1056            "Step 3 (speech tokenizer): {:.1}ms",
1057            t1.elapsed().as_secs_f64() * 1000.0
1058        );
1059        let t2 = std::time::Instant::now();
1060        // Step 4: Tokenize target text with chat template
1061        let chat_text = format!("<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n");
1062        let encoding = self
1063            .text_tokenizer
1064            .encode(chat_text.as_str(), false)
1065            .map_err(|e| FerrumError::model(format!("tokenize: {e}")))?;
1066        let input_ids: Vec<u32> = encoding.get_ids().to_vec();
1067        // role = input_ids[..3], text_content = input_ids[3..input_ids.len()-5]
1068        let role_ids = &input_ids[..3];
1069        let text_content_ids = &input_ids[3..input_ids.len().saturating_sub(5)];
1070
1071        // Tokenize ref_text
1072        let ref_chat_text = format!("<|im_start|>assistant\n{ref_text}<|im_end|>\n");
1073        let ref_encoding = self
1074            .text_tokenizer
1075            .encode(ref_chat_text.as_str(), false)
1076            .map_err(|e| FerrumError::model(format!("tokenize ref: {e}")))?;
1077        let ref_ids: Vec<u32> = ref_encoding.get_ids().to_vec();
1078        // ref text content: ref_ids[3..ref_ids.len()-2]
1079        let ref_text_ids = &ref_ids[3..ref_ids.len().saturating_sub(2)];
1080
1081        // Step 5: Build prefill prompt (dual-stream text+codec summing)
1082        self.talker.reset();
1083
1084        let tts_bos = self.config.tts_bos_token_id;
1085        let tts_eos = self.config.tts_eos_token_id;
1086        let tts_pad = self.config.tts_pad_token_id;
1087        let codec_bos = self.config.codec_bos_id;
1088        let codec_eos = self.config.codec_eos_token_id;
1089        let codec_pad = self.config.codec_pad_id;
1090
1091        // Helper: embed codec and text tokens
1092        let embed_codec_ids = |ids: &[u32]| -> Result<Tensor> {
1093            let t = Tensor::new(ids, &device)
1094                .map_err(|e| FerrumError::model(format!("codec tensor: {e}")))?
1095                .unsqueeze(0)
1096                .map_err(|e| FerrumError::model(format!("codec unsqueeze: {e}")))?;
1097            self.talker.embed_codec(&t)
1098        };
1099        let embed_text_ids = |ids: &[u32]| -> Result<Tensor> {
1100            let t = Tensor::new(ids, &device)
1101                .map_err(|e| FerrumError::model(format!("text tensor: {e}")))?
1102                .unsqueeze(0)
1103                .map_err(|e| FerrumError::model(format!("text unsqueeze: {e}")))?;
1104            self.talker.embed_text(&t)
1105        };
1106
1107        // Get tts special embeddings
1108        let tts_special = embed_text_ids(&[tts_bos, tts_eos, tts_pad])?;
1109        let tts_bos_embed = tts_special
1110            .narrow(1, 0, 1)
1111            .map_err(|e| FerrumError::model(format!("tts_bos narrow: {e}")))?;
1112        let tts_eos_embed = tts_special
1113            .narrow(1, 1, 1)
1114            .map_err(|e| FerrumError::model(format!("tts_eos narrow: {e}")))?;
1115        let tts_pad_embed = tts_special
1116            .narrow(1, 2, 1)
1117            .map_err(|e| FerrumError::model(format!("tts_pad narrow: {e}")))?;
1118
1119        // Resolve language_id — "auto" defaults to "chinese"
1120        let resolved_lang = if language.eq_ignore_ascii_case("auto") {
1121            "chinese"
1122        } else {
1123            language
1124        };
1125        let language_id = self
1126            .config
1127            .codec_language_id
1128            .get(&resolved_lang.to_lowercase());
1129
1130        // Codec prefix: [think, think_bos, lang, think_eos] or [nothink, think_bos, think_eos]
1131        let codec_prefix_ids = if let Some(&lang_id) = language_id {
1132            vec![
1133                self.config.codec_think_id,
1134                self.config.codec_think_bos_id,
1135                lang_id,
1136                self.config.codec_think_eos_id,
1137            ]
1138        } else {
1139            vec![
1140                self.config.codec_nothink_id,
1141                self.config.codec_think_bos_id,
1142                self.config.codec_think_eos_id,
1143            ]
1144        };
1145        let codec_prefix_embed = embed_codec_ids(&codec_prefix_ids)?;
1146
1147        // Codec suffix: [pad, bos]
1148        let codec_suffix_embed = embed_codec_ids(&[codec_pad, codec_bos])?;
1149
1150        // Speaker embed inserted between prefix and suffix
1151        let codec_input = Tensor::cat(&[&codec_prefix_embed, &spk_embed, &codec_suffix_embed], 1)
1152            .map_err(|e| FerrumError::model(format!("codec_input cat: {e}")))?;
1153        let codec_len = codec_input
1154            .dim(1)
1155            .map_err(|e| FerrumError::model(format!("codec_len dim: {e}")))?;
1156
1157        // Role embedding
1158        let role_embed = embed_text_ids(role_ids)?;
1159
1160        // Text-codec prefix: (codec_len - 2) pads + tts_bos, summed with codec[:-1]
1161        let n_pads = codec_len - 2;
1162        let mut text_prefix_parts = Vec::new();
1163        for _ in 0..n_pads {
1164            text_prefix_parts.push(tts_pad_embed.clone());
1165        }
1166        text_prefix_parts.push(tts_bos_embed.clone());
1167        let text_prefix_refs: Vec<&Tensor> = text_prefix_parts.iter().collect();
1168        let text_prefix = Tensor::cat(&text_prefix_refs, 1)
1169            .map_err(|e| FerrumError::model(format!("text_prefix cat: {e}")))?;
1170        let codec_prefix_part = codec_input
1171            .narrow(1, 0, codec_len - 1)
1172            .map_err(|e| FerrumError::model(format!("codec prefix narrow: {e}")))?;
1173        let text_codec_prefix = (&text_prefix + &codec_prefix_part)
1174            .map_err(|e| FerrumError::model(format!("text+codec prefix sum: {e}")))?;
1175
1176        // ICL mode: prefill is 9 positions (no first_text+codec_bos)
1177        let prefill_embed = Tensor::cat(&[&role_embed, &text_codec_prefix], 1)
1178            .map_err(|e| FerrumError::model(format!("prefill cat: {e}")))?;
1179
1180        let t3 = std::time::Instant::now();
1181
1182        // Step 6b: Build ICL block — [ref_text, target_text, tts_eos] + [codec_bos, ref_codec]
1183        // Streaming mode: element-wise sum, trailing = text beyond codec length
1184        let all_text_ids: Vec<u32> = ref_text_ids
1185            .iter()
1186            .chain(text_content_ids.iter())
1187            .copied()
1188            .collect();
1189        let text_embed = embed_text_ids(&all_text_ids)?;
1190        let text_embed_with_eos = Tensor::cat(&[&text_embed, &tts_eos_embed], 1)
1191            .map_err(|e| FerrumError::model(format!("text+eos cat: {e}")))?;
1192        let text_len = text_embed_with_eos
1193            .dim(1)
1194            .map_err(|e| FerrumError::model(format!("text_len dim: {e}")))?;
1195
1196        // Codec stream: batch all codebook embeddings in one shot
1197        // Collect all first-codebook IDs → single batch embed
1198        let t_codec_start = std::time::Instant::now();
1199        let ncg = self.config.num_code_groups;
1200        let first_codes: Vec<u32> = ref_codes.iter().map(|f| f[0]).collect();
1201        let codec_frames_cat = {
1202            // Batch embed main codec: [1, nframes, hidden]
1203            let mut sum = embed_codec_ids(&first_codes)?;
1204            // Batch embed each sub-codebook and accumulate
1205            for cb in 0..(ncg - 1) {
1206                let codes: Vec<u32> = ref_codes.iter().map(|f| f[cb + 1]).collect();
1207                let codes_t = Tensor::new(codes.as_slice(), &device)
1208                    .map_err(|e| FerrumError::model(format!("batch codes: {e}")))?
1209                    .unsqueeze(0)
1210                    .map_err(|e| FerrumError::model(format!("batch unsqueeze: {e}")))?;
1211                let sub_embed = codes_t
1212                    .apply(&self.sub_talker.codec_embeddings[cb])
1213                    .map_err(|e| FerrumError::model(format!("batch sub_embed: {e}")))?;
1214                sum =
1215                    (sum + sub_embed).map_err(|e| FerrumError::model(format!("batch add: {e}")))?;
1216            }
1217            sum
1218        };
1219        info!(
1220            "Codec embedding: {:.1}ms ({} frames × {} codebooks)",
1221            t_codec_start.elapsed().as_secs_f64() * 1000.0,
1222            ref_codes.len(),
1223            ncg
1224        );
1225
1226        // Prepend codec_bos to codec frames: [codec_bos_embed, codec_frames]
1227        let t_merge = std::time::Instant::now();
1228        let codec_bos_for_icl = embed_codec_ids(&[codec_bos])?;
1229        let icl_codec = Tensor::cat(&[&codec_bos_for_icl, &codec_frames_cat], 1)
1230            .map_err(|e| FerrumError::model(format!("icl_codec cat: {e}")))?;
1231        let codec_icl_len = icl_codec
1232            .dim(1)
1233            .map_err(|e| FerrumError::model(format!("codec_icl_len dim: {e}")))?;
1234
1235        // Build ICL embed: element-wise sum of text and codec (streaming mode)
1236        let icl_trailing: Tensor;
1237        let icl_embed: Tensor;
1238        if text_len > codec_icl_len {
1239            let text_part = text_embed_with_eos
1240                .narrow(1, 0, codec_icl_len)
1241                .map_err(|e| FerrumError::model(format!("text_part narrow: {e}")))?;
1242            icl_embed = (&text_part + &icl_codec)
1243                .map_err(|e| FerrumError::model(format!("text+codec sum: {e}")))?;
1244            icl_trailing = text_embed_with_eos
1245                .narrow(1, codec_icl_len, text_len - codec_icl_len)
1246                .map_err(|e| FerrumError::model(format!("trailing narrow: {e}")))?;
1247        } else {
1248            let n_pad = codec_icl_len - text_len;
1249            let text_padded = if n_pad > 0 {
1250                let pad_block = tts_pad_embed
1251                    .expand((1, n_pad, self.config.hidden_size))
1252                    .map_err(|e| FerrumError::model(format!("pad expand: {e}")))?;
1253                Tensor::cat(&[&text_embed_with_eos, &pad_block], 1)
1254                    .map_err(|e| FerrumError::model(format!("text_padded cat: {e}")))?
1255            } else {
1256                text_embed_with_eos.clone()
1257            };
1258            icl_embed = (&text_padded + &icl_codec)
1259                .map_err(|e| FerrumError::model(format!("padded+codec sum: {e}")))?;
1260            icl_trailing = tts_pad_embed.clone();
1261        }
1262        let trailing_text_len = icl_trailing
1263            .dim(1)
1264            .map_err(|e| FerrumError::model(format!("trailing dim: {e}")))?;
1265
1266        // Debug: dump values for comparison with reference project
1267        // Step 6c: Run prefill then ICL block as SEPARATE forward passes
1268        let _prefill_out = self.talker.forward_step(&prefill_embed)?;
1269        let t_icl = std::time::Instant::now();
1270        let icl_hidden = self.talker.forward_step(&icl_embed)?;
1271        let icl_len = icl_hidden
1272            .dim(1)
1273            .map_err(|e| FerrumError::model(format!("icl_hidden dim: {e}")))?;
1274        info!(
1275            "ICL block: {:.1}ms ({} tokens), trailing={}",
1276            t_icl.elapsed().as_secs_f64() * 1000.0,
1277            icl_len,
1278            trailing_text_len
1279        );
1280
1281        // Use ICL hidden output for logits and decode
1282        let mut hidden = icl_hidden;
1283        let hidden_len = hidden
1284            .dim(1)
1285            .map_err(|e| FerrumError::model(format!("hidden dim: {e}")))?;
1286        let last_hidden = hidden
1287            .narrow(1, hidden_len - 1, 1)
1288            .map_err(|e| FerrumError::model(format!("narrow last: {e}")))?;
1289        if let Ok(v) = last_hidden.flatten_all().and_then(|t| t.to_vec1::<f32>()) {}
1290        let current_logits = self.talker.logits(&last_hidden)?;
1291        {}
1292
1293        // Decode loop
1294        let mut all_codec_tokens: Vec<Vec<u32>> = Vec::new();
1295        let mut current_logits = current_logits;
1296
1297        // Suppress special tokens [vocab_size-1024, vocab_size) except EOS
1298        let suppress_start = self.config.vocab_size.saturating_sub(1024);
1299        let suppress_end = self.config.vocab_size;
1300
1301        // ICL mode: stronger repetition penalty (matching reference Rust project)
1302        // + repetition detection for early stop.
1303        const ICL_REPETITION_PENALTY: f32 = 1.5;
1304        const ICL_FRAMES_PER_TOKEN: usize = 6;
1305        const ICL_MIN_FRAMES: usize = 75;
1306        let max_icl_tokens = ICL_MIN_FRAMES.max(text_content_ids.len() * ICL_FRAMES_PER_TOKEN);
1307        let mut generated_tokens: Vec<u32> = Vec::new();
1308
1309        for step in 0..max_icl_tokens {
1310            let mut logits_vec = logits_to_vec(&current_logits)?;
1311            // Suppress special tokens [vocab-1024, vocab) except EOS
1312            for i in suppress_start..suppress_end.min(logits_vec.len()) {
1313                if i as u32 != codec_eos {
1314                    logits_vec[i] = f32::NEG_INFINITY;
1315                }
1316            }
1317            // min_new_tokens: suppress EOS until we've generated a minimum
1318            // number of frames. Reference Python uses sentence-length heuristic.
1319            // FERRUM_TTS_MIN_FRAMES env lets us tune without rebuilding.
1320            let min_frames = tts_runtime_env()
1321                .min_frames
1322                .unwrap_or_else(|| text_content_ids.len() * ICL_FRAMES_PER_TOKEN);
1323            if step < min_frames {
1324                if let Some(v) = logits_vec.get_mut(codec_eos as usize) {
1325                    *v = f32::NEG_INFINITY;
1326                }
1327            }
1328            // Repetition penalty with token history
1329            for &prev_tok in &generated_tokens {
1330                let idx = prev_tok as usize;
1331                if idx < logits_vec.len() {
1332                    if logits_vec[idx] > 0.0 {
1333                        logits_vec[idx] /= ICL_REPETITION_PENALTY;
1334                    } else {
1335                        logits_vec[idx] *= ICL_REPETITION_PENALTY;
1336                    }
1337                }
1338            }
1339            let next_token = sample_token(
1340                &logits_vec,
1341                tts_temperature(),
1342                TOP_K,
1343                ICL_REPETITION_PENALTY,
1344            );
1345
1346            generated_tokens.push(next_token);
1347
1348            if next_token == codec_eos {
1349                info!("TTS voice clone: codec EOS at step {}", step);
1350                break;
1351            }
1352
1353            // Repetition detection: check for repeating patterns of length 1-4
1354            if generated_tokens.len() >= 6 {
1355                let n = generated_tokens.len();
1356                let mut is_repeat = false;
1357                for pat_len in 1..=4 {
1358                    if n >= pat_len * 3 {
1359                        let a = &generated_tokens[n - pat_len * 3..n - pat_len * 2];
1360                        let b = &generated_tokens[n - pat_len * 2..n - pat_len];
1361                        let c = &generated_tokens[n - pat_len..n];
1362                        if a == b && b == c {
1363                            is_repeat = true;
1364                            break;
1365                        }
1366                    }
1367                }
1368                if is_repeat {
1369                    info!(
1370                        "TTS voice clone: repetition detected at step {}, stopping",
1371                        step
1372                    );
1373                    break;
1374                }
1375            }
1376
1377            let cur_hidden_len = hidden
1378                .dim(1)
1379                .map_err(|e| FerrumError::model(format!("hidden dim: {e}")))?;
1380            let last_hidden = hidden
1381                .narrow(1, cur_hidden_len - 1, 1)
1382                .map_err(|e| FerrumError::model(format!("last_hidden: {e}")))?;
1383
1384            let token_tensor = Tensor::new(&[next_token], &device)
1385                .map_err(|e| FerrumError::model(format!("token tensor: {e}")))?
1386                .unsqueeze(0)
1387                .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?;
1388            let first_codec_embed = self.talker.embed_codec(&token_tensor)?;
1389
1390            let extra_codes = self.sub_talker.predict(
1391                &last_hidden,
1392                &first_codec_embed,
1393                st_temperature(),
1394                TOP_K,
1395            )?;
1396
1397            let mut frame_codes = vec![next_token];
1398            frame_codes.extend_from_slice(&extra_codes);
1399            all_codec_tokens.push(frame_codes);
1400
1401            // Sum all codebook embeddings on GPU
1402            let mut combined_embed = first_codec_embed.clone();
1403            for (i, &code) in extra_codes.iter().enumerate() {
1404                let code_t = Tensor::new(&[code], &device)
1405                    .and_then(|t| t.unsqueeze(0))
1406                    .map_err(|e| FerrumError::model(format!("code_t: {e}")))?;
1407                let sub_embed = code_t
1408                    .apply(&self.sub_talker.codec_embeddings[i])
1409                    .map_err(|e| FerrumError::model(format!("sub_embed: {e}")))?;
1410                combined_embed = (combined_embed + sub_embed)
1411                    .map_err(|e| FerrumError::model(format!("add embed: {e}")))?;
1412            }
1413
1414            // Add trailing text or tts_pad (matching reference streaming mode)
1415            if step < trailing_text_len {
1416                let trail = icl_trailing
1417                    .narrow(1, step, 1)
1418                    .map_err(|e| FerrumError::model(format!("trailing narrow: {e}")))?;
1419                combined_embed = (combined_embed + trail)
1420                    .map_err(|e| FerrumError::model(format!("add trailing: {e}")))?;
1421            } else {
1422                combined_embed = (combined_embed + &tts_pad_embed)
1423                    .map_err(|e| FerrumError::model(format!("add tts_pad: {e}")))?;
1424            }
1425
1426            hidden = self.talker.forward_step(&combined_embed)?;
1427            current_logits = self.talker.logits(&hidden)?;
1428        }
1429
1430        if all_codec_tokens.is_empty() {
1431            return Err(FerrumError::model("no codec tokens generated"));
1432        }
1433        info!(
1434            "TTS voice clone: generated {} codec frames",
1435            all_codec_tokens.len()
1436        );
1437
1438        // Step 7: Prepend ref codes and decode with vocoder
1439        let mut all_codes_with_ref = ref_codes.clone();
1440        all_codes_with_ref.extend_from_slice(&all_codec_tokens);
1441
1442        let num_frames = all_codes_with_ref.len();
1443        let num_groups = self.config.num_code_groups;
1444
1445        // Build codec tensor [1, num_groups, T]
1446        let mut flat_codes: Vec<u32> = vec![0; num_groups * num_frames];
1447        for (t, frame) in all_codes_with_ref.iter().enumerate() {
1448            for (g, &code) in frame.iter().take(num_groups).enumerate() {
1449                flat_codes[g * num_frames + t] = code;
1450            }
1451        }
1452
1453        // Clamp to valid range
1454        let codebook_size = 2048u32;
1455        for code in &mut flat_codes {
1456            if *code >= codebook_size {
1457                *code = 0;
1458            }
1459        }
1460
1461        let codes_tensor = Tensor::new(&flat_codes[..], &device)
1462            .map_err(|e| FerrumError::model(format!("codes tensor: {e}")))?
1463            .reshape((1, num_groups, num_frames))
1464            .map_err(|e| FerrumError::model(format!("reshape codes: {e}")))?;
1465
1466        let waveform = self.vocoder.decode(&codes_tensor)?;
1467
1468        let samples: Vec<f32> = waveform
1469            .squeeze(0)
1470            .map_err(|e| FerrumError::model(format!("squeeze batch: {e}")))?
1471            .squeeze(0)
1472            .map_err(|e| FerrumError::model(format!("squeeze channel: {e}")))?
1473            .to_vec1()
1474            .map_err(|e| FerrumError::model(format!("to_vec1: {e}")))?;
1475
1476        // Trim reference portion
1477        let ref_ratio = ref_frames as f64 / num_frames as f64;
1478        let cut = (ref_ratio * samples.len() as f64) as usize;
1479        let output_samples = samples[cut..].to_vec();
1480
1481        info!(
1482            "TTS voice clone: waveform {} samples ({:.2}s), trimmed ref {} samples",
1483            output_samples.len(),
1484            output_samples.len() as f64 / SAMPLE_RATE as f64,
1485            cut,
1486        );
1487
1488        Ok(output_samples)
1489    }
1490}
1491
1492// ── Utility functions ───────────────────────────────────────────────────
1493
1494/// Find safetensor files matching a prefix in a directory.
1495fn find_safetensor_files(dir: &std::path::Path, prefix: &str) -> Result<Vec<std::path::PathBuf>> {
1496    // Try single file first
1497    let single = dir.join(format!("{prefix}.safetensors"));
1498    if single.exists() {
1499        return Ok(vec![single]);
1500    }
1501
1502    // Try sharded: prefix-00001-of-00005.safetensors, etc.
1503    let mut files: Vec<std::path::PathBuf> = Vec::new();
1504    if let Ok(entries) = std::fs::read_dir(dir) {
1505        for entry in entries.flatten() {
1506            let path = entry.path();
1507            if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
1508                if name.starts_with(prefix)
1509                    && name.ends_with(".safetensors")
1510                    && name != format!("{prefix}.safetensors")
1511                {
1512                    files.push(path);
1513                }
1514            }
1515        }
1516    }
1517    files.sort();
1518
1519    if files.is_empty() {
1520        Err(FerrumError::model(format!(
1521            "no safetensors files with prefix '{prefix}' in {}",
1522            dir.display()
1523        )))
1524    } else {
1525        Ok(files)
1526    }
1527}
1528
1529/// Load a BPE tokenizer from vocab.json + merges.txt.
1530fn load_bpe_tokenizer(dir: &std::path::Path) -> Result<tokenizers::Tokenizer> {
1531    // Try tokenizer.json first (HF fast tokenizer format)
1532    let tokenizer_json = dir.join("tokenizer.json");
1533    if tokenizer_json.exists() {
1534        return tokenizers::Tokenizer::from_file(&tokenizer_json)
1535            .map_err(|e| FerrumError::model(format!("load tokenizer.json: {e}")));
1536    }
1537
1538    // Fallback: build from vocab.json + merges.txt
1539    let vocab_path = dir.join("vocab.json");
1540    let merges_path = dir.join("merges.txt");
1541
1542    if !vocab_path.exists() || !merges_path.exists() {
1543        return Err(FerrumError::model(
1544            "tokenizer.json not found, and vocab.json + merges.txt not found either",
1545        ));
1546    }
1547
1548    let vocab_data = std::fs::read_to_string(&vocab_path)
1549        .map_err(|e| FerrumError::model(format!("read vocab.json: {e}")))?;
1550    let vocab: HashMap<String, u32> = serde_json::from_str(&vocab_data)
1551        .map_err(|e| FerrumError::model(format!("parse vocab.json: {e}")))?;
1552
1553    let merges_data = std::fs::read_to_string(&merges_path)
1554        .map_err(|e| FerrumError::model(format!("read merges.txt: {e}")))?;
1555    let merges: Vec<(String, String)> = merges_data
1556        .lines()
1557        .skip(1) // skip header line
1558        .filter(|line| !line.is_empty())
1559        .filter_map(|line| {
1560            let parts: Vec<&str> = line.splitn(2, ' ').collect();
1561            if parts.len() == 2 {
1562                Some((parts[0].to_string(), parts[1].to_string()))
1563            } else {
1564                None
1565            }
1566        })
1567        .collect();
1568
1569    let bpe = tokenizers::models::bpe::BPE::from_file(
1570        vocab_path.to_str().unwrap(),
1571        merges_path.to_str().unwrap(),
1572    )
1573    .build()
1574    .map_err(|e| FerrumError::model(format!("build BPE: {e}")))?;
1575
1576    let tokenizer = tokenizers::Tokenizer::new(bpe);
1577    Ok(tokenizer)
1578}
1579
1580/// Extract logits as Vec<f32> from a [1, 1, vocab] or [1, vocab] tensor.
1581fn logits_to_vec(logits: &Tensor) -> Result<Vec<f32>> {
1582    let logits = if logits.dims().len() == 3 {
1583        logits
1584            .squeeze(0)
1585            .map_err(|e| FerrumError::model(format!("squeeze: {e}")))?
1586            .squeeze(0)
1587            .map_err(|e| FerrumError::model(format!("squeeze: {e}")))?
1588    } else if logits.dims().len() == 2 {
1589        logits
1590            .squeeze(0)
1591            .map_err(|e| FerrumError::model(format!("squeeze: {e}")))?
1592    } else {
1593        logits.clone()
1594    };
1595
1596    logits
1597        .to_vec1()
1598        .map_err(|e| FerrumError::model(format!("logits to_vec1: {e}")))
1599}
1600
1601/// Sample a token from logits with temperature, top-k, and repetition penalty.
1602/// Sample a token matching qwen3-tts-rs reference:
1603/// 1. temperature scaling
1604/// 2. top-k filter (keep top_k, rest = -inf)
1605/// 3. top-p filter (keep smallest set with cumprob > top_p, rest = -inf)
1606/// 4. softmax over filtered logits
1607/// 5. multinomial sample from distribution
1608pub fn sample_token(
1609    logits: &[f32],
1610    temperature: f32,
1611    top_k: usize,
1612    _repetition_penalty: f32,
1613) -> u32 {
1614    if temperature < 0.01 {
1615        return argmax(logits);
1616    }
1617
1618    let vocab = logits.len();
1619
1620    // 1. Apply temperature
1621    let scaled: Vec<f32> = logits.iter().map(|&x| x / temperature).collect();
1622
1623    // 2. Top-k filter: keep top_k values, set rest to -inf
1624    let mut filtered = scaled.clone();
1625    if top_k > 0 && top_k < vocab {
1626        let mut sorted = scaled.clone();
1627        sorted.sort_unstable_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
1628        let threshold = sorted[top_k - 1];
1629        for v in &mut filtered {
1630            if *v < threshold {
1631                *v = f32::NEG_INFINITY;
1632            }
1633        }
1634    }
1635
1636    // 3. Top-p filter: keep smallest set of tokens whose cumulative prob >= top_p
1637    const TOP_P: f32 = 0.9;
1638    {
1639        let mut indices: Vec<usize> = (0..vocab).collect();
1640        indices.sort_unstable_by(|&a, &b| {
1641            filtered[b]
1642                .partial_cmp(&filtered[a])
1643                .unwrap_or(std::cmp::Ordering::Equal)
1644        });
1645
1646        // Softmax over sorted values for cumulative prob
1647        let max_val = filtered[indices[0]];
1648        let exp_sorted: Vec<f32> = indices
1649            .iter()
1650            .map(|&i| (filtered[i] - max_val).exp())
1651            .collect();
1652        let sum: f32 = exp_sorted.iter().sum();
1653        let probs_sorted: Vec<f32> = exp_sorted.iter().map(|e| e / sum).collect();
1654
1655        // Find cutoff
1656        let mut cumsum = 0.0f32;
1657        let mut cutoff_idx = vocab;
1658        for (i, &p) in probs_sorted.iter().enumerate() {
1659            cumsum += p;
1660            if cumsum > TOP_P {
1661                cutoff_idx = i + 1;
1662                break;
1663            }
1664        }
1665
1666        // Mask out tokens beyond cutoff
1667        for &idx in &indices[cutoff_idx..] {
1668            filtered[idx] = f32::NEG_INFINITY;
1669        }
1670    }
1671
1672    // 4. Softmax over filtered logits
1673    let max_val = filtered.iter().copied().fold(f32::NEG_INFINITY, f32::max);
1674    let exps: Vec<f32> = filtered.iter().map(|&v| (v - max_val).exp()).collect();
1675    let sum: f32 = exps.iter().sum();
1676    let probs: Vec<f32> = exps.iter().map(|e| e / sum).collect();
1677
1678    // 5. Multinomial sample
1679    let r = rand_f32();
1680    let mut cumulative = 0.0f32;
1681    for (i, &p) in probs.iter().enumerate() {
1682        cumulative += p;
1683        if cumulative >= r {
1684            return i as u32;
1685        }
1686    }
1687    // Fallback
1688    argmax(&probs)
1689}
1690
/// Index of the largest value in `v`, or `0` when `v` is empty.
///
/// When several entries compare as maximal (including incomparable NaN
/// pairs, which are treated as equal), the last such index wins.
fn argmax(v: &[f32]) -> u32 {
    let mut best: Option<(usize, f32)> = None;
    for (idx, &val) in v.iter().enumerate() {
        best = match best {
            // Keep the incumbent only when it is strictly greater.
            Some((_, top)) if top > val => best,
            _ => Some((idx, val)),
        };
    }
    best.map_or(0, |(idx, _)| idx as u32)
}
1698
/// Cheap non-cryptographic random `f32`, uniform in `[0, 1)`.
///
/// Each call combines the wall clock's sub-second nanoseconds with a
/// process-wide call counter, then scrambles the result with the SplitMix64
/// finalizer.
///
/// The previous LCG-style step `(seed + count) * 1103515245 + 12345` never
/// came close to wrapping a `u64`, so dividing by `u64::MAX` only produced
/// values in roughly `[0, 0.06)`, which skewed multinomial sampling toward
/// low token ids. The finalizer spreads the input over all 64 bits, and only
/// the top 24 bits are used so the value is exactly representable in `f32`
/// and can never round up to `1.0`.
fn rand_f32() -> f32 {
    use std::sync::atomic::{AtomicU64, Ordering};
    static COUNTER: AtomicU64 = AtomicU64::new(0);

    let seed = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .subsec_nanos() as u64;
    let count = COUNTER.fetch_add(1, Ordering::Relaxed);

    // SplitMix64: golden-ratio increment per call, followed by the finalizer.
    let mut z = seed.wrapping_add(count.wrapping_mul(0x9E37_79B9_7F4A_7C15));
    z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    z ^= z >> 31;

    // 24 random bits -> [0, 1) with no rounding.
    ((z >> 40) as f32) / ((1u64 << 24) as f32)
}
1716
1717// ── Dummy KV cache + ModelExecutor trait impl ───────────────────────────
1718
1719#[derive(Clone, Debug)]
1720#[allow(dead_code)]
1721struct DummyTtsCache;
1722
1723impl ferrum_interfaces::KvCacheHandle for DummyTtsCache {
1724    fn block_table(&self) -> &ferrum_interfaces::BlockTable {
1725        static EMPTY: std::sync::OnceLock<ferrum_interfaces::BlockTable> =
1726            std::sync::OnceLock::new();
1727        EMPTY.get_or_init(|| ferrum_interfaces::BlockTable::new(16))
1728    }
1729    fn block_table_mut(&mut self) -> &mut ferrum_interfaces::BlockTable {
1730        unimplemented!()
1731    }
1732    fn as_any(&self) -> &dyn std::any::Any {
1733        self
1734    }
1735    fn device(&self) -> Device {
1736        Device::CPU
1737    }
1738    fn num_layers(&self) -> usize {
1739        0
1740    }
1741    fn num_heads(&self) -> usize {
1742        0
1743    }
1744    fn head_dim(&self) -> usize {
1745        0
1746    }
1747    fn key_cache(&self, _: usize) -> Result<Option<TensorRef>> {
1748        Ok(None)
1749    }
1750    fn value_cache(&self, _: usize) -> Result<Option<TensorRef>> {
1751        Ok(None)
1752    }
1753    fn clone_handle(&self) -> Result<Arc<dyn ferrum_interfaces::KvCacheHandle>> {
1754        Ok(Arc::new(self.clone()))
1755    }
1756    fn stats(&self) -> ferrum_interfaces::CacheHandleStats {
1757        ferrum_interfaces::CacheHandleStats {
1758            memory_bytes: 0,
1759            blocks_allocated: 0,
1760            tokens_stored: 0,
1761            utilization: 0.0,
1762            last_access: std::time::Instant::now(),
1763        }
1764    }
1765    fn is_valid(&self) -> bool {
1766        true
1767    }
1768    fn cache_id(&self) -> String {
1769        "tts_dummy".to_string()
1770    }
1771}
1772
1773#[async_trait]
1774impl ModelExecutor for TtsModelExecutor {
1775    fn info(&self) -> &ModelInfo {
1776        &self.info
1777    }
1778
1779    async fn prefill(&self, _input: &PrefillInput) -> Result<PrefillOutput> {
1780        Err(FerrumError::model(
1781            "TTS uses synthesize(), not prefill/decode",
1782        ))
1783    }
1784
1785    async fn decode(&self, _input: &DecodeInput) -> Result<DecodeOutput> {
1786        Err(FerrumError::model(
1787            "TTS uses synthesize(), not prefill/decode",
1788        ))
1789    }
1790
1791    fn capabilities(&self) -> ExecutorCapabilities {
1792        ExecutorCapabilities {
1793            max_batch_size: 1,
1794            max_sequence_length: self.info.max_sequence_length,
1795            attention_mechanisms: vec![AttentionType::GroupedQuery],
1796            supports_dynamic_batching: false,
1797            supports_continuous_batching: false,
1798            supports_speculative_decoding: false,
1799            supports_tensor_parallelism: false,
1800            supports_pipeline_parallelism: false,
1801            supported_dtypes: vec![DataType::FP32, DataType::BF16],
1802            supported_devices: vec![self.info.device.clone()],
1803            memory_requirements: MemoryRequirements {
1804                parameter_memory: 0,
1805                activation_memory_per_token: 0,
1806                kv_cache_memory_per_token: 0,
1807                overhead_memory: 0,
1808            },
1809        }
1810    }
1811
1812    fn release_cache(&self, _: &str) {}
1813
1814    fn status(&self) -> ferrum_interfaces::model_executor::ExecutorStatus {
1815        common::default_executor_status()
1816    }
1817}
1818
#[cfg(test)]
mod tests {
    use super::*;

    /// Well-formed override values are picked up verbatim.
    #[test]
    fn tts_runtime_env_parses_overrides() {
        let overrides = [
            ("FERRUM_TTS_TEMP", "0.7"),
            ("FERRUM_ST_TEMP", "0.2"),
            ("FERRUM_REF_PCM", "/tmp/ref.pcm"),
            ("FERRUM_REF_CODES", "/tmp/ref.codes"),
            ("FERRUM_TTS_MIN_FRAMES", "128"),
        ];
        let parsed = TtsRuntimeEnv::from_env_vars(overrides);

        // Temperatures
        assert_eq!(parsed.tts_temperature, 0.7);
        assert_eq!(parsed.st_temperature(), 0.2);
        // Reference inputs
        assert_eq!(parsed.ref_pcm.as_deref(), Some("/tmp/ref.pcm"));
        assert_eq!(parsed.ref_codes.as_deref(), Some("/tmp/ref.codes"));
        // Frame floor
        assert_eq!(parsed.min_frames, Some(128));
    }

    /// Unparseable values fall back to the defaults.
    #[test]
    fn tts_runtime_env_defaults_invalid_values() {
        let garbage = [
            ("FERRUM_TTS_TEMP", "invalid"),
            ("FERRUM_ST_TEMP", "invalid"),
            ("FERRUM_TTS_MIN_FRAMES", "invalid"),
        ];
        let parsed = TtsRuntimeEnv::from_env_vars(garbage);

        assert_eq!(parsed.tts_temperature, TEMPERATURE);
        assert_eq!(parsed.st_temperature(), TEMPERATURE);
        assert_eq!(parsed.min_frames, None);
    }
}