Skip to main content

ferrum_models/executor/
tts_executor.rs

1//! Qwen3-TTS Executor — text-to-speech pipeline wiring Talker LM + Vocoder.
2//!
//! Implements: text tokenization, autoregressive codec token generation,
//! SubTalker prediction of the remaining code groups, vocoder waveform synthesis.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use async_trait::async_trait;
10use candle_core::{DType, Device as CandleDevice, Tensor};
11use candle_nn::VarBuilder;
12use ferrum_interfaces::{
13    model_executor::{
14        AttentionType, DecodeInput, DecodeOutput, ExecutorCapabilities, MemoryRequirements,
15        PrefillInput, PrefillOutput,
16    },
17    ModelExecutor, TensorRef,
18};
19use ferrum_types::{DataType, Device, FerrumError, ModelInfo, ModelType, Result};
20use tracing::info;
21
22use super::common;
23use crate::architectures::qwen3_tts::{Qwen3TTSTalker, SubTalker, TalkerConfig};
24use crate::architectures::qwen3_tts_backbone::TalkerBackboneBackend;
25use crate::architectures::qwen3_tts_vocoder::{Qwen3TTSVocoder, VocoderConfig};
26use crate::architectures::speaker_encoder::{mel_spectrogram_speaker_encoder, SpeakerEncoder};
27use crate::architectures::speech_tokenizer_encoder::SpeechTokenizerEncoder;
28use ferrum_quantization::NativeSafetensorsLoader;
29
/// Replace the transformer stacks of the Talker and the SubTalker with
/// `TalkerBackboneBackend<CudaBackend>` instances so they run through
/// ferrum-kernels (cuBLAS + CUDA kernels) rather than the fused fallback path.
///
/// Weights are re-opened from `model_dir` through a
/// `NativeSafetensorsLoader<CudaBackend>`. The Talker receives the main
/// backbone and the SubTalker the code-predictor backbone, in that order.
///
/// # Errors
///
/// Propagates any loader or backbone construction failure. If the second
/// backbone fails to build, the Talker override has already been installed.
#[cfg(feature = "cuda")]
fn install_cuda_backend_overrides(
    cfg: &TalkerConfig,
    model_dir: &std::path::Path,
    talker: &mut Qwen3TTSTalker,
    sub_talker: &mut SubTalker,
) -> Result<()> {
    use ferrum_kernels::backend::cuda::CudaBackend;

    type CudaBackbone = TalkerBackboneBackend<CudaBackend>;

    let weights = NativeSafetensorsLoader::<CudaBackend>::open(model_dir)?;

    talker.set_backend_override(Box::new(CudaBackbone::new(cfg, &weights)?));
    sub_talker.set_backend_override(Box::new(CudaBackbone::new_code_predictor(cfg, &weights)?));

    Ok(())
}
48
// ── Constants ────────────────────────────────────────────────────────────

/// Output sample rate in Hz, used when reporting the waveform duration.
const SAMPLE_RATE: usize = 24_000;
/// Upper bound on autoregressive decode steps (codec frames) per utterance.
const MAX_CODEC_TOKENS: usize = 2_000;

// Sampling parameters for codec token generation.

/// Default sampling temperature. The `FERRUM_TTS_TEMP` env var overrides it
/// (0.0 = greedy, 0.9 = default sampling).
const TEMPERATURE: f32 = 0.9;
/// Top-k cutoff handed to the token samplers.
const TOP_K: usize = 50;
/// Penalty applied to the logits of codec tokens that were already generated.
const REPETITION_PENALTY: f32 = 1.05;
59
60fn tts_temperature() -> f32 {
61    std::env::var("FERRUM_TTS_TEMP")
62        .ok()
63        .and_then(|v| v.parse().ok())
64        .unwrap_or(TEMPERATURE)
65}
66
67fn st_temperature() -> f32 {
68    std::env::var("FERRUM_ST_TEMP")
69        .ok()
70        .and_then(|v| v.parse().ok())
71        .unwrap_or(tts_temperature())
72}
73
/// Qwen3-TTS executor: text-to-speech synthesis.
///
/// Bundles every component that `from_path` loads from a model directory.
pub struct TtsModelExecutor {
    /// Main Talker LM: embeds text/codec tokens and produces the logits for
    /// the first codec token of each frame.
    talker: Qwen3TTSTalker,
    /// Code predictor that yields the remaining codec tokens of a frame from
    /// the Talker's last hidden state and the first codec embedding.
    sub_talker: SubTalker,
    /// Decodes a codec tensor into a waveform.
    vocoder: Qwen3TTSVocoder,
    /// Text tokenizer used to encode the input text.
    text_tokenizer: tokenizers::Tokenizer,
    /// Talker configuration parsed from `config.json`.
    config: TalkerConfig,
    /// Model metadata assembled at load time.
    info: ModelInfo,
    /// Speaker encoder loaded from the `speaker_encoder` prefix of the Talker
    /// weights; `None` when loading it failed.
    speaker_encoder: Option<SpeakerEncoder>,
    /// Speech tokenizer encoder, loaded on CPU; `None` when
    /// `speech_tokenizer/config.json` is missing or loading it failed.
    speech_tokenizer_encoder: Option<SpeechTokenizerEncoder>,
}
85
86impl TtsModelExecutor {
87    /// Load from model directory containing:
88    /// - config.json (TalkerConfig)
89    /// - model.safetensors (Talker weights)
90    /// - speech_tokenizer/model.safetensors (Vocoder weights)
91    /// - tokenizer_config.json + vocab.json + merges.txt (text tokenizer)
92    pub fn from_path(model_path: &str, device: CandleDevice, dtype: DType) -> Result<Self> {
93        let dir = std::path::Path::new(model_path);
94
95        // Parse TalkerConfig from config.json
96        let config_json: serde_json::Value = {
97            let config_path = dir.join("config.json");
98            let data = std::fs::read_to_string(&config_path)
99                .map_err(|e| FerrumError::model(format!("read config.json: {e}")))?;
100            serde_json::from_str(&data)
101                .map_err(|e| FerrumError::model(format!("parse config.json: {e}")))?
102        };
103        let config = TalkerConfig::from_json(&config_json)?;
104
105        // Load text tokenizer from vocab.json + merges.txt
106        let text_tokenizer = load_bpe_tokenizer(dir)?;
107
108        // Load Talker weights from model.safetensors (or sharded)
109        let talker_weights = find_safetensor_files(dir, "model")?;
110        let talker_vb = unsafe {
111            VarBuilder::from_mmaped_safetensors(&talker_weights, dtype, &device)
112                .map_err(|e| FerrumError::model(format!("load talker weights: {e}")))?
113        };
114        let mut talker = Qwen3TTSTalker::load(&config, talker_vb.clone(), device.clone())?;
115
116        // Load SubTalker (code predictor) from same weights file
117        let mut sub_talker = SubTalker::load(&config, talker_vb.clone(), device.clone())?;
118
119        // Load Speaker Encoder (for voice cloning, base models only)
120        let spk_enc_dim = config_json
121            .get("speaker_encoder_config")
122            .and_then(|c| c.get("enc_dim"))
123            .and_then(|v| v.as_u64())
124            .unwrap_or(1024) as usize;
125        let speaker_encoder =
126            SpeakerEncoder::load_with_dim(talker_vb.pp("speaker_encoder"), spk_enc_dim)
127                .map_err(|e| {
128                    tracing::warn!("Speaker encoder not available: {e}");
129                    e
130                })
131                .ok();
132
133        // Load Vocoder weights from speech_tokenizer/model.safetensors
134        let vocoder_dir = dir.join("speech_tokenizer");
135        let vocoder_weights = find_safetensor_files(&vocoder_dir, "model")?;
136        let vocoder_vb = unsafe {
137            VarBuilder::from_mmaped_safetensors(&vocoder_weights, dtype, &device)
138                .map_err(|e| FerrumError::model(format!("load vocoder weights: {e}")))?
139        };
140        let vocoder_config = VocoderConfig::default();
141        let vocoder = Qwen3TTSVocoder::load(&vocoder_config, vocoder_vb.clone())?;
142
143        // Load Speech Tokenizer Encoder on CPU — Metal float32 accumulation order
144        // causes transformer output divergence that amplifies through RVQ codebook lookup.
145        // CPU is exact and encoder only runs once per reference audio.
146        let speech_tokenizer_encoder = if vocoder_dir.join("config.json").exists() {
147            let cpu_vb = unsafe {
148                VarBuilder::from_mmaped_safetensors(&vocoder_weights, dtype, &CandleDevice::Cpu)
149                    .map_err(|e| FerrumError::model(format!("load encoder cpu: {e}")))?
150            };
151            SpeechTokenizerEncoder::load(cpu_vb.pp("encoder"), CandleDevice::Cpu)
152                .map_err(|e| {
153                    tracing::warn!("Speech tokenizer encoder not available: {e}");
154                    e
155                })
156                .ok()
157        } else {
158            None
159        };
160
161        // Install a Backend<B>-backed transformer for the Talker/SubTalker
162        // when running on CUDA. The legacy `ferrum_attention::FusedTransformer`
163        // CUDA module is a stub and its Linux CPU fallback uses naive fp64
164        // matmul — the CUDA-only voice-clone regression traces back to that
165        // path accumulating precision drift through the 20-layer decoder.
166        // Routing through LlamaFamilyModel<CudaBackend> gives us cuBLAS +
167        // ferrum-kernels for the transformer stack while keeping candle
168        // embeddings / projection / codec_head unchanged.
169        #[cfg(feature = "cuda")]
170        if matches!(&device, CandleDevice::Cuda(_)) {
171            match install_cuda_backend_overrides(&config, dir, &mut talker, &mut sub_talker) {
172                Ok(()) => {
173                    tracing::info!(
174                        "TtsModelExecutor: Backend<CudaBackend> installed for Talker + SubTalker"
175                    );
176                }
177                Err(e) => {
178                    tracing::warn!(
179                        "TtsModelExecutor: Backend<CudaBackend> install failed ({e}); \
180                         falling back to candle/fused path (CUDA voice-clone may produce garbage)"
181                    );
182                }
183            }
184        }
185
186        let info = ModelInfo {
187            model_id: ferrum_types::ModelId(model_path.to_string()),
188            model_type: ModelType::Custom("qwen3-tts".to_string()),
189            hidden_size: config.hidden_size,
190            vocab_size: config.vocab_size,
191            num_layers: config.num_hidden_layers,
192            num_heads: config.num_attention_heads,
193            num_kv_heads: config.num_key_value_heads,
194            num_parameters: 0,
195            max_sequence_length: config.max_position_embeddings,
196            device: match &device {
197                CandleDevice::Cpu => Device::CPU,
198                CandleDevice::Cuda(_) => Device::CUDA(0),
199                #[cfg(any(target_os = "macos", target_os = "ios"))]
200                CandleDevice::Metal(_) => Device::Metal,
201                #[cfg(not(any(target_os = "macos", target_os = "ios")))]
202                CandleDevice::Metal(_) => Device::CPU,
203            },
204            dtype: match dtype {
205                DType::F32 => DataType::FP32,
206                DType::F16 => DataType::FP16,
207                DType::BF16 => DataType::BF16,
208                _ => DataType::FP32,
209            },
210            version: None,
211            license: None,
212            metadata: HashMap::new(),
213        };
214
215        info!(
216            "TtsModelExecutor: {} (hidden={}, layers={}, codec_groups={})",
217            model_path, config.hidden_size, config.num_hidden_layers, config.num_code_groups,
218        );
219
220        Ok(Self {
221            talker,
222            sub_talker,
223            vocoder,
224            text_tokenizer,
225            config,
226            info,
227            speaker_encoder,
228            speech_tokenizer_encoder,
229        })
230    }
231
232    /// Synthesize speech from text.
233    ///
234    /// Returns PCM samples at 24kHz as Vec<f32>.
235    ///
236    /// Prompt structure (matches Python/qwen3-tts-rs):
237    ///   Prefill: [role_prefix(3)] + [tts_text_prefix(6) + codec_prefix(6)] + [first_text + codec_bos]
238    ///   Trailing: text_projection(remaining_text + tts_eos) — added per decode step
239    pub fn synthesize(&mut self, text: &str, language: &str) -> Result<Vec<f32>> {
240        self.talker.reset();
241
242        let device = self.talker.device().clone();
243
244        // 1. Tokenize text (raw content only, no chat template)
245        let encoding = self
246            .text_tokenizer
247            .encode(text, false)
248            .map_err(|e| FerrumError::model(format!("tokenize: {e}")))?;
249        let content_ids: Vec<u32> = encoding.get_ids().to_vec();
250
251        if content_ids.is_empty() {
252            return Err(FerrumError::model("empty text after tokenization"));
253        }
254
255        info!("TTS: content tokens = {}", content_ids.len());
256
257        let codec_eos = self.config.codec_eos_token_id;
258        let tts_pad = self.config.tts_pad_token_id;
259        let tts_bos = self.config.tts_bos_token_id;
260        let tts_eos = self.config.tts_eos_token_id;
261
262        // Helper: embed text token IDs through text_embedding + text_projection
263        let embed_text_ids = |ids: &[u32]| -> Result<Tensor> {
264            let t = Tensor::new(ids, &device)
265                .and_then(|t| t.unsqueeze(0))
266                .map_err(|e| FerrumError::model(format!("text ids: {e}")))?;
267            self.talker.embed_text(&t)
268        };
269        let embed_codec_ids = |ids: &[u32]| -> Result<Tensor> {
270            let t = Tensor::new(ids, &device)
271                .and_then(|t| t.unsqueeze(0))
272                .map_err(|e| FerrumError::model(format!("codec ids: {e}")))?;
273            self.talker.embed_codec(&t)
274        };
275
276        // 2. Build prefill (matching qwen3-tts-rs prefill_custom_voice)
277
278        // Role prefix: text_projection([<|im_start|>, assistant, \n])
279        // These are fixed special token IDs from the tokenizer
280        let im_start_id = 151644u32; // <|im_start|>
281        let assistant_id = 77091u32; // "assistant"
282        let newline_id = 198u32; // "\n"
283        let role_prefix_ids = [im_start_id, assistant_id, newline_id];
284
285        info!(
286            "TTS: role_prefix={} content={} tokens",
287            role_prefix_ids.len(),
288            content_ids.len()
289        );
290
291        // Role prefix embedding (text_projection)
292        let role_embed = embed_text_ids(&role_prefix_ids)?;
293
294        // Codec prefix: [think, think_bos, lang, think_eos, speaker, pad]
295        let resolved_lang = if language == "auto" {
296            "chinese"
297        } else {
298            language
299        };
300        let language_id = self
301            .config
302            .codec_language_id
303            .get(&resolved_lang.to_lowercase());
304        let codec_prefix_ids = if let Some(&lang_id) = language_id {
305            vec![
306                self.config.codec_think_id,
307                self.config.codec_think_bos_id,
308                lang_id,
309                self.config.codec_think_eos_id,
310            ]
311        } else {
312            vec![
313                self.config.codec_nothink_id,
314                self.config.codec_think_bos_id,
315                self.config.codec_think_eos_id,
316            ]
317        };
318        // Codec sequence: [think, think_bos, lang, think_eos, SPEAKER, pad, bos]
319        // Speaker: use Vivian (3065) for Chinese, Ryan (3061) for English
320        let speaker_token = if resolved_lang == "chinese" {
321            3065u32
322        } else {
323            3061u32
324        };
325        let codec_full = {
326            let mut v = codec_prefix_ids.clone();
327            v.push(speaker_token);
328            v.push(self.config.codec_pad_id);
329            v.push(self.config.codec_bos_id);
330            v
331        };
332        let codec_embed = embed_codec_ids(&codec_full)?;
333
334        // tts_text_prefix: [tts_pad * (codec_len-1), tts_bos]
335        let n_codec = codec_full.len();
336        let mut tts_prefix_ids = vec![tts_pad; n_codec - 1];
337        tts_prefix_ids.push(tts_bos);
338        let tts_prefix_embed = embed_text_ids(&tts_prefix_ids)?;
339
340        // Sum: tts_prefix + codec_prefix[:-1]
341        let codec_first = codec_embed
342            .narrow(1, 0, n_codec - 1)
343            .map_err(|e| FerrumError::model(format!("codec narrow: {e}")))?;
344        // Actually we need codec_first to have same length as tts_prefix
345        // tts_prefix has n_codec elements, codec_first has n_codec-1
346        // Let me re-read the reference... codec_embed has 6 tokens [think..pad,bos], tts_prefix has 6 [pad*5,bos]
347        // They sum first 6 codec with first 6 tts_prefix, then codec_bos is separate
348        // codec_full has [think, think_bos, lang, think_eos, speaker, pad, bos] = 7 tokens
349        // tts_prefix: [pad*5, bos] overlaid on codec[0:6] (first 6, excluding last bos)
350        // Then: first_text + codec_bos (last token) summed
351        let n_prefix = n_codec - 1; // 6: everything except codec_bos
352        let codec_prefix_part = codec_embed
353            .narrow(1, 0, n_prefix)
354            .map_err(|e| FerrumError::model(format!("codec narrow: {e}")))?;
355
356        // tts text prefix: [pad * (n_prefix-1), bos]
357        let mut tts_text_prefix_ids = vec![tts_pad; n_prefix - 1];
358        tts_text_prefix_ids.push(tts_bos);
359        let tts_text_embed = embed_text_ids(&tts_text_prefix_ids)?;
360
361        let codec_hidden = (&tts_text_embed + &codec_prefix_part)
362            .map_err(|e| FerrumError::model(format!("prefix sum: {e}")))?;
363
364        // codec_bos is the last element of codec_full
365        let codec_bos_embed = codec_embed
366            .narrow(1, n_prefix, 1)
367            .map_err(|e| FerrumError::model(format!("codec bos: {e}")))?;
368
369        // First text token + codec_bos (summed)
370        let first_text_combined = if !content_ids.is_empty() {
371            let first_text_embed = embed_text_ids(&content_ids[..1])?;
372            (&first_text_embed + &codec_bos_embed)
373                .map_err(|e| FerrumError::model(format!("first text+bos: {e}")))?
374        } else {
375            codec_bos_embed.clone()
376        };
377
378        // Full prefill: [role_prefix, codec_hidden, first_text+codec_bos]
379        let prefill_embeds = Tensor::cat(&[&role_embed, &codec_hidden, &first_text_combined], 1)
380            .map_err(|e| FerrumError::model(format!("prefill cat: {e}")))?;
381
382        let plen = prefill_embeds.dim(1).unwrap_or(0);
383        info!("TTS: prefill_len = {}", plen);
384        // Dump prefill input for comparison with reference
385        if let Ok(v) = prefill_embeds
386            .narrow(0, 0, 1)
387            .and_then(|t| t.narrow(1, 0, 1))
388            .and_then(|t| t.narrow(2, 0, 5))
389            .and_then(|t| t.flatten_all())
390            .and_then(|t| t.to_vec1::<f32>())
391        {
392            info!("  prefill_input pos0[:5] = {:?}", v);
393        }
394        if plen > 0 {
395            if let Ok(v) = prefill_embeds
396                .narrow(0, 0, 1)
397                .and_then(|t| t.narrow(1, plen - 1, 1))
398                .and_then(|t| t.narrow(2, 0, 5))
399                .and_then(|t| t.flatten_all())
400                .and_then(|t| t.to_vec1::<f32>())
401            {
402                info!("  prefill_input pos-1[:5] = {:?}", v);
403            }
404        }
405
406        // 3. Build trailing text: text_projection(remaining_content + tts_eos)
407        let mut trailing_ids: Vec<u32> = if content_ids.len() > 1 {
408            content_ids[1..].to_vec()
409        } else {
410            Vec::new()
411        };
412        trailing_ids.push(tts_eos);
413        let trailing_text_embeds = embed_text_ids(&trailing_ids)?;
414        let trailing_text_len = trailing_text_embeds
415            .dim(1)
416            .map_err(|e| FerrumError::model(format!("trailing dim: {e}")))?;
417        let tts_pad_embed = embed_text_ids(&[tts_pad])?;
418
419        info!("TTS: trailing_text_len = {}", trailing_text_len);
420
421        // 4. Prefill: forward through transformer
422        let mut hidden = self.talker.forward_step(&prefill_embeds)?;
423
424        // Get logits from last position
425        let current_logits = self.talker.logits(
426            &hidden
427                .narrow(1, hidden.dim(1).unwrap() - 1, 1)
428                .map_err(|e| FerrumError::model(format!("narrow: {e}")))?,
429        )?;
430
431        // 4. Autoregressive decode loop: generate codec token 0 per step
432        let mut all_codec_tokens: Vec<Vec<u32>> = Vec::new();
433        let mut current_logits = current_logits;
434
435        // Token suppression: mask [vocab_size-1024, vocab_size) except EOS
436        let suppress_start = self.config.vocab_size.saturating_sub(1024);
437        let suppress_end = self.config.vocab_size;
438        let mut generated_tokens: Vec<u32> = Vec::new();
439
440        for step in 0..MAX_CODEC_TOKENS {
441            // Sample next codec token with suppression + repetition penalty
442            let mut logits_vec = logits_to_vec(&current_logits)?;
443            // Suppress special tokens
444            for i in suppress_start..suppress_end.min(logits_vec.len()) {
445                if i as u32 != codec_eos {
446                    logits_vec[i] = f32::NEG_INFINITY;
447                }
448            }
449            // Repetition penalty (matching Python's repetition_penalty=1.05)
450            for &prev_tok in &generated_tokens {
451                let idx = prev_tok as usize;
452                if idx < logits_vec.len() {
453                    if logits_vec[idx] > 0.0 {
454                        logits_vec[idx] /= REPETITION_PENALTY;
455                    } else {
456                        logits_vec[idx] *= REPETITION_PENALTY;
457                    }
458                }
459            }
460            let next_token =
461                sample_token(&logits_vec, tts_temperature(), TOP_K, REPETITION_PENALTY);
462
463            if step < 10 {
464                // Find argmax (greedy token)
465                let argmax_tok = logits_vec
466                    .iter()
467                    .enumerate()
468                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
469                    .map(|(i, v)| (i, *v))
470                    .unwrap_or((0, 0.0));
471                info!(
472                    "TOKEN step={} sampled={} argmax=({}, {:.2})",
473                    step, next_token, argmax_tok.0, argmax_tok.1
474                );
475            }
476
477            generated_tokens.push(next_token);
478
479            // Check for EOS
480            if next_token == codec_eos {
481                info!("TTS: codec EOS at step {}", step);
482                break;
483            }
484
485            // Get last hidden state from talker for SubTalker
486            let last_hidden = hidden
487                .narrow(1, hidden.dim(1).unwrap() - 1, 1)
488                .map_err(|e| FerrumError::model(format!("last_hidden: {e}")))?;
489
490            // Embed first codec token
491            let token_tensor = Tensor::new(&[next_token], &device)
492                .map_err(|e| FerrumError::model(format!("token tensor: {e}")))?
493                .unsqueeze(0)
494                .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?;
495            let first_codec_embed = self.talker.embed_codec(&token_tensor)?;
496
497            // SubTalker: predict remaining codec tokens 1..num_code_groups-1
498            let st_t0 = std::time::Instant::now();
499            let extra_codes = self.sub_talker.predict(
500                &last_hidden,
501                &first_codec_embed,
502                st_temperature(),
503                TOP_K,
504            )?;
505            if step == 0 {
506                info!(
507                    "  SubTalker: {:.1}ms",
508                    st_t0.elapsed().as_secs_f64() * 1000.0
509                );
510            }
511
512            let mut frame_codes = vec![next_token];
513            frame_codes.extend_from_slice(&extra_codes);
514            all_codec_tokens.push(frame_codes);
515
516            // Build combined embedding for next talker step:
517            // sum of all codec embeddings (token 0 from main talker + tokens 1-15 from sub-talker)
518            let mut combined_embed = first_codec_embed.clone();
519            for (i, &code) in extra_codes.iter().enumerate() {
520                let code_t = Tensor::new(&[code], &device)
521                    .and_then(|t| t.unsqueeze(0))
522                    .map_err(|e| FerrumError::model(format!("code_t: {e}")))?;
523                let sub_embed = code_t
524                    .apply(&self.sub_talker.codec_embeddings[i])
525                    .map_err(|e| FerrumError::model(format!("sub_embed: {e}")))?;
526                combined_embed = (combined_embed + sub_embed)
527                    .map_err(|e| FerrumError::model(format!("add embed: {e}")))?;
528            }
529
530            if step == 0 {
531                if let Ok(v) = combined_embed
532                    .flatten_all()
533                    .and_then(|t| t.narrow(0, 0, 5))
534                    .and_then(|t| t.to_vec1::<f32>())
535                {
536                    info!("STEP0 codec_sum[:5] = {:?} (before trailing)", v);
537                }
538            }
539            // Add trailing text embedding (guides generation toward target text)
540            // Python: inputs_embeds = codec_sum + trailing_text_hidden[:, gen_step]
541            // trailing_text = text_projection(remaining_text) + tts_eos
542            // For basic TTS, trailing covers all text tokens after the first
543            if step < trailing_text_len {
544                let trail = trailing_text_embeds
545                    .narrow(1, step, 1)
546                    .map_err(|e| FerrumError::model(format!("trailing narrow: {e}")))?;
547                combined_embed = (combined_embed + trail)
548                    .map_err(|e| FerrumError::model(format!("add trailing: {e}")))?;
549            } else {
550                combined_embed = (combined_embed + &tts_pad_embed)
551                    .map_err(|e| FerrumError::model(format!("add tts_pad: {e}")))?;
552            }
553
554            // Debug: dump step 0 components
555            if step == 0 {
556                if let Ok(v) = first_codec_embed
557                    .flatten_all()
558                    .and_then(|t| t.narrow(0, 0, 5))
559                    .and_then(|t| t.to_vec1::<f32>())
560                {
561                    info!("STEP0 semantic[:5] = {:?}", v);
562                }
563                if let Ok(v) = combined_embed
564                    .flatten_all()
565                    .and_then(|t| t.narrow(0, 0, 5))
566                    .and_then(|t| t.to_vec1::<f32>())
567                {
568                    info!("STEP0 combined[:5] = {:?}", v);
569                }
570            }
571            // Forward combined embedding through talker
572            let tk_t0 = std::time::Instant::now();
573            hidden = self.talker.forward_step(&combined_embed)?;
574            current_logits = self.talker.logits(&hidden)?;
575            if step == 0 {
576                info!(
577                    "  Talker step: {:.1}ms",
578                    tk_t0.elapsed().as_secs_f64() * 1000.0
579                );
580            }
581        }
582
583        if all_codec_tokens.is_empty() {
584            return Err(FerrumError::model("no codec tokens generated"));
585        }
586
587        info!("TTS: generated {} codec frames", all_codec_tokens.len());
588
589        // 5. Build codec tensor [1, num_code_groups, T] for vocoder
590        let num_frames = all_codec_tokens.len();
591        let num_groups = self.config.num_code_groups;
592        let mut flat_codes: Vec<u32> = vec![0; num_groups * num_frames];
593        for (t, frame) in all_codec_tokens.iter().enumerate() {
594            for (g, &code) in frame.iter().enumerate() {
595                flat_codes[g * num_frames + t] = code;
596            }
597        }
598
599        // Clamp codes to valid codebook range [0, codebook_size-1].
600        // Special tokens (codec_bos, codec_eos, etc.) are >= codebook_size and must be clamped.
601        let codebook_size = 2048u32;
602        for code in &mut flat_codes {
603            if *code >= codebook_size {
604                *code = 0; // replace special tokens with pad/silence
605            }
606        }
607
608        let codes_tensor = Tensor::new(&flat_codes[..], &device)
609            .map_err(|e| FerrumError::model(format!("codes tensor: {e}")))?
610            .reshape((1, num_groups, num_frames))
611            .map_err(|e| FerrumError::model(format!("reshape codes: {e}")))?;
612
613        // 6. Vocoder: codec tokens → waveform
614        let waveform = self.vocoder.decode(&codes_tensor)?;
615
616        // Extract samples: [1, 1, samples] → Vec<f32>
617        let samples: Vec<f32> = waveform
618            .squeeze(0)
619            .map_err(|e| FerrumError::model(format!("squeeze batch: {e}")))?
620            .squeeze(0)
621            .map_err(|e| FerrumError::model(format!("squeeze channel: {e}")))?
622            .to_vec1()
623            .map_err(|e| FerrumError::model(format!("to_vec1: {e}")))?;
624
625        info!(
626            "TTS: waveform {} samples ({:.2}s @ {}Hz)",
627            samples.len(),
628            samples.len() as f64 / SAMPLE_RATE as f64,
629            SAMPLE_RATE,
630        );
631
632        Ok(samples)
633    }
634
635    /// Decode a chunk of codec frames to audio samples.
636    /// frames: Vec<Vec<u32>> where each inner Vec has num_code_groups elements.
637    fn decode_frames(&mut self, frames: &[Vec<u32>], device: &CandleDevice) -> Result<Vec<f32>> {
638        let num_frames = frames.len();
639        if num_frames == 0 {
640            return Ok(vec![]);
641        }
642        let num_groups = self.config.num_code_groups;
643        let codebook_size = 2048u32;
644
645        let mut flat_codes: Vec<u32> = vec![0; num_groups * num_frames];
646        for (t, frame) in frames.iter().enumerate() {
647            for (g, &code) in frame.iter().take(num_groups).enumerate() {
648                flat_codes[g * num_frames + t] = if code >= codebook_size { 0 } else { code };
649            }
650        }
651
652        let codes_tensor = Tensor::new(&flat_codes[..], device)
653            .map_err(|e| FerrumError::model(format!("codes tensor: {e}")))?
654            .reshape((1, num_groups, num_frames))
655            .map_err(|e| FerrumError::model(format!("reshape codes: {e}")))?;
656
657        let waveform = self.vocoder.decode(&codes_tensor)?;
658        waveform
659            .squeeze(0)
660            .and_then(|t| t.squeeze(0))
661            .and_then(|t| t.to_vec1())
662            .map_err(|e| FerrumError::model(format!("waveform extract: {e}")))
663    }
664
665    /// Streaming TTS: calls `on_chunk` with each audio chunk as soon as it's ready.
666    ///
667    /// Each chunk is `chunk_frames` codec frames decoded to audio (~800ms at default 10 frames).
668    /// First chunk arrives after `chunk_frames` decode steps (~2-3s for 0.6B).
669    pub fn synthesize_streaming<F: FnMut(usize, &[f32])>(
670        &mut self,
671        text: &str,
672        language: &str,
673        chunk_frames: usize,
674        mut on_chunk: F,
675    ) -> Result<Vec<Vec<f32>>> {
676        // Reuse existing synthesize setup (prefill + trailing text)
677        // but yield audio in chunks instead of all at once
678        self.talker.reset();
679        let device = self.talker.device().clone();
680
681        let encoding = self
682            .text_tokenizer
683            .encode(text, false)
684            .map_err(|e| FerrumError::model(format!("tokenize: {e}")))?;
685        let content_ids: Vec<u32> = encoding.get_ids().to_vec();
686        if content_ids.is_empty() {
687            return Err(FerrumError::model("empty text after tokenization"));
688        }
689
690        let codec_eos = self.config.codec_eos_token_id;
691        let tts_pad = self.config.tts_pad_token_id;
692        let tts_bos = self.config.tts_bos_token_id;
693        let tts_eos = self.config.tts_eos_token_id;
694
695        // Build embeddings (same as synthesize)
696        let embed_text_ids = |ids: &[u32]| -> Result<Tensor> {
697            let t = Tensor::new(ids, &device)
698                .map_err(|e| FerrumError::model(format!("text tensor: {e}")))?
699                .unsqueeze(0)
700                .map_err(|e| FerrumError::model(format!("text unsqueeze: {e}")))?;
701            self.talker.embed_text(&t)
702        };
703        let embed_codec_ids = |ids: &[u32]| -> Result<Tensor> {
704            let t = Tensor::new(ids, &device)
705                .map_err(|e| FerrumError::model(format!("codec tensor: {e}")))?
706                .unsqueeze(0)
707                .map_err(|e| FerrumError::model(format!("codec unsqueeze: {e}")))?;
708            self.talker.embed_codec(&t)
709        };
710
711        // Codec/text prefix (same as synthesize)
712        let resolved_lang = if language.eq_ignore_ascii_case("auto") {
713            "chinese"
714        } else {
715            language
716        };
717        let language_id = self
718            .config
719            .codec_language_id
720            .get(&resolved_lang.to_lowercase());
721        let codec_prefix_ids = if let Some(&lang_id) = language_id {
722            vec![
723                self.config.codec_think_id,
724                self.config.codec_think_bos_id,
725                lang_id,
726                self.config.codec_think_eos_id,
727            ]
728        } else {
729            vec![
730                self.config.codec_nothink_id,
731                self.config.codec_think_bos_id,
732                self.config.codec_think_eos_id,
733            ]
734        };
735        let speaker_token = if resolved_lang == "chinese" {
736            3065u32
737        } else {
738            3061u32
739        };
740        let mut codec_ids = codec_prefix_ids;
741        codec_ids.push(speaker_token);
742        codec_ids.push(self.config.codec_pad_id);
743        codec_ids.push(self.config.codec_bos_id);
744        let codec_embed = embed_codec_ids(&codec_ids)?;
745        let n_codec = codec_embed
746            .dim(1)
747            .map_err(|e| FerrumError::model(format!("dim: {e}")))?;
748        let n_prefix = n_codec - 1;
749        let codec_prefix_part = codec_embed
750            .narrow(1, 0, n_prefix)
751            .map_err(|e| FerrumError::model(format!("narrow: {e}")))?;
752
753        let mut tts_text_prefix_ids = vec![tts_pad; n_prefix - 1];
754        tts_text_prefix_ids.push(tts_bos);
755        let tts_text_embed = embed_text_ids(&tts_text_prefix_ids)?;
756        let codec_hidden = (&tts_text_embed + &codec_prefix_part)
757            .map_err(|e| FerrumError::model(format!("sum: {e}")))?;
758        let codec_bos_embed = codec_embed
759            .narrow(1, n_prefix, 1)
760            .map_err(|e| FerrumError::model(format!("bos: {e}")))?;
761
762        let role_ids: &[u32] = &[151644, 77091, 198]; // im_start, assistant, \n
763        let role_embed = embed_text_ids(role_ids)?;
764
765        let first_text_combined = if !content_ids.is_empty() {
766            let first_text_embed = embed_text_ids(&content_ids[..1])?;
767            (&first_text_embed + &codec_bos_embed)
768                .map_err(|e| FerrumError::model(format!("first: {e}")))?
769        } else {
770            codec_bos_embed.clone()
771        };
772
773        let prefill_embeds = Tensor::cat(&[&role_embed, &codec_hidden, &first_text_combined], 1)
774            .map_err(|e| FerrumError::model(format!("prefill cat: {e}")))?;
775
776        // Trailing text
777        let trailing_text_embeds = if content_ids.len() > 1 {
778            let remaining = embed_text_ids(&content_ids[1..])?;
779            let eos = embed_text_ids(&[tts_eos])?;
780            Tensor::cat(&[&remaining, &eos], 1)
781                .map_err(|e| FerrumError::model(format!("trailing: {e}")))?
782        } else {
783            embed_text_ids(&[tts_eos])?
784        };
785        let trailing_text_len = trailing_text_embeds
786            .dim(1)
787            .map_err(|e| FerrumError::model(format!("dim: {e}")))?;
788        let tts_pad_embed = embed_text_ids(&[tts_pad])?;
789
790        // Prefill
791        let mut hidden = self.talker.forward_step(&prefill_embeds)?;
792        let mut current_logits = self.talker.logits(
793            &hidden
794                .narrow(1, hidden.dim(1).unwrap() - 1, 1)
795                .map_err(|e| FerrumError::model(format!("narrow: {e}")))?,
796        )?;
797
798        // Streaming decode loop
799        let suppress_start = self.config.vocab_size.saturating_sub(1024);
800        let suppress_end = self.config.vocab_size;
801        let mut generated_tokens: Vec<u32> = Vec::new();
802        let mut frame_buffer: Vec<Vec<u32>> = Vec::new();
803        let mut audio_chunks: Vec<Vec<f32>> = Vec::new();
804
805        for step in 0..MAX_CODEC_TOKENS {
806            let mut logits_vec = logits_to_vec(&current_logits)?;
807            for i in suppress_start..suppress_end.min(logits_vec.len()) {
808                if i as u32 != codec_eos {
809                    logits_vec[i] = f32::NEG_INFINITY;
810                }
811            }
812            for &prev_tok in &generated_tokens {
813                let idx = prev_tok as usize;
814                if idx < logits_vec.len() {
815                    if logits_vec[idx] > 0.0 {
816                        logits_vec[idx] /= REPETITION_PENALTY;
817                    } else {
818                        logits_vec[idx] *= REPETITION_PENALTY;
819                    }
820                }
821            }
822            let next_token =
823                sample_token(&logits_vec, tts_temperature(), TOP_K, REPETITION_PENALTY);
824            generated_tokens.push(next_token);
825
826            if next_token == codec_eos {
827                info!("TTS streaming: EOS at step {}", step);
828                break;
829            }
830
831            let last_hidden = hidden
832                .narrow(1, hidden.dim(1).unwrap() - 1, 1)
833                .map_err(|e| FerrumError::model(format!("narrow: {e}")))?;
834            let token_tensor = Tensor::new(&[next_token], &device)
835                .and_then(|t| t.unsqueeze(0))
836                .map_err(|e| FerrumError::model(format!("tok: {e}")))?;
837            let first_codec_embed = self.talker.embed_codec(&token_tensor)?;
838
839            let extra_codes = self.sub_talker.predict(
840                &last_hidden,
841                &first_codec_embed,
842                st_temperature(),
843                TOP_K,
844            )?;
845
846            let mut frame = vec![next_token];
847            frame.extend_from_slice(&extra_codes);
848            frame_buffer.push(frame);
849
850            // Emit chunk when buffer is full
851            if frame_buffer.len() >= chunk_frames {
852                let chunk_audio = self.decode_frames(&frame_buffer, &device)?;
853                on_chunk(audio_chunks.len(), &chunk_audio);
854                audio_chunks.push(chunk_audio);
855                frame_buffer.clear();
856            }
857
858            // Build combined embed for next step
859            let mut combined_embed = first_codec_embed.clone();
860            for (i, &code) in extra_codes.iter().enumerate() {
861                let code_t = Tensor::new(&[code], &device)
862                    .and_then(|t| t.unsqueeze(0))
863                    .map_err(|e| FerrumError::model(format!("code_t: {e}")))?;
864                let sub_embed = code_t
865                    .apply(&self.sub_talker.codec_embeddings[i])
866                    .map_err(|e| FerrumError::model(format!("sub: {e}")))?;
867                combined_embed = (combined_embed + sub_embed)
868                    .map_err(|e| FerrumError::model(format!("add: {e}")))?;
869            }
870            if step < trailing_text_len {
871                let trail = trailing_text_embeds
872                    .narrow(1, step, 1)
873                    .map_err(|e| FerrumError::model(format!("trail: {e}")))?;
874                combined_embed = (combined_embed + trail)
875                    .map_err(|e| FerrumError::model(format!("add trail: {e}")))?;
876            } else {
877                combined_embed = (combined_embed + &tts_pad_embed)
878                    .map_err(|e| FerrumError::model(format!("pad: {e}")))?;
879            }
880
881            hidden = self.talker.forward_step(&combined_embed)?;
882            current_logits = self.talker.logits(&hidden)?;
883        }
884
885        // Flush remaining frames
886        if !frame_buffer.is_empty() {
887            let chunk_audio = self.decode_frames(&frame_buffer, &device)?;
888            on_chunk(audio_chunks.len(), &chunk_audio);
889            audio_chunks.push(chunk_audio);
890        }
891
892        Ok(audio_chunks)
893    }
894
895    /// Get the output sample rate.
896    pub fn sample_rate(&self) -> usize {
897        SAMPLE_RATE
898    }
899
900    pub fn config(&self) -> &TalkerConfig {
901        &self.config
902    }
903
904    /// Synthesize speech with voice cloning from a reference audio.
905    ///
906    /// Uses ICL (in-context learning) prompting: the reference audio is
907    /// encoded to codec tokens and prepended to the generation prompt,
908    /// along with a speaker embedding extracted via ECAPA-TDNN.
909    ///
910    /// Returns PCM samples at 24kHz as Vec<f32>.
911    pub fn synthesize_voice_clone(
912        &mut self,
913        text: &str,
914        language: &str,
915        ref_audio_path: &str,
916        ref_text: &str,
917    ) -> Result<Vec<f32>> {
918        let device = self.talker.device().clone();
919
920        // Step 1: Load and process reference audio at 24kHz
921        let ref_pcm = if let Ok(path) = std::env::var("FERRUM_REF_PCM") {
922            let data = std::fs::read(&path)
923                .map_err(|e| FerrumError::model(format!("read ref pcm: {e}")))?;
924            let pcm: Vec<f32> = data
925                .chunks(4)
926                .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
927                .collect();
928            info!("Loaded ref PCM override: {} samples", pcm.len());
929            pcm
930        } else {
931            crate::audio_processor::load_audio_at_rate(ref_audio_path, 24000)?
932        };
933        info!(
934            "TTS voice clone: loaded ref audio {} samples ({:.2}s)",
935            ref_pcm.len(),
936            ref_pcm.len() as f64 / 24000.0
937        );
938
939        let t0 = std::time::Instant::now();
940        // Step 2: Extract speaker embedding via ECAPA-TDNN
941        let speaker_encoder = self
942            .speaker_encoder
943            .as_ref()
944            .ok_or_else(|| FerrumError::model("speaker encoder not loaded"))?;
945        let mel = mel_spectrogram_speaker_encoder(&ref_pcm);
946        let n_mel_frames = mel.len() / 128;
947        // mel_spectrogram_speaker_encoder returns [T, 128] row-major
948        // forward() internally transposes [1, T, 128] → [1, 128, T]
949        let mel_tensor = Tensor::from_vec(mel, (1, n_mel_frames, 128), &device)
950            .map_err(|e| FerrumError::model(format!("mel tensor: {e}")))?;
951        let spk_embed = speaker_encoder.forward(&mel_tensor)?;
952        info!(
953            "Step 2 (speaker embed): {:.1}ms",
954            t0.elapsed().as_secs_f64() * 1000.0
955        );
956        // spk_embed shape: [enc_dim] -> reshape to [1, 1, hidden_size]
957        let spk_embed = spk_embed
958            .unsqueeze(0)
959            .map_err(|e| FerrumError::model(format!("spk unsqueeze(0): {e}")))?
960            .unsqueeze(0)
961            .map_err(|e| FerrumError::model(format!("spk unsqueeze(0) 2: {e}")))?;
962
963        let t1 = std::time::Instant::now();
964        // Step 3: Encode reference audio to codec tokens (ICL)
965        let speech_enc = self
966            .speech_tokenizer_encoder
967            .as_ref()
968            .ok_or_else(|| FerrumError::model("speech tokenizer encoder not loaded"))?;
969        // Allow pre-computed codec tokens for debugging (FERRUM_REF_CODES=/path/to/codes.bin)
970        let ref_codes = if let Ok(path) = std::env::var("FERRUM_REF_CODES") {
971            let data = std::fs::read(&path)
972                .map_err(|e| FerrumError::model(format!("read ref codes: {e}")))?;
973            let u32s: Vec<u32> = data
974                .chunks(4)
975                .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
976                .collect();
977            let ncb = self.config.num_code_groups;
978            let nframes = u32s.len() / ncb;
979            info!(
980                "Loaded pre-computed ref codes: {} frames from {}",
981                nframes, path
982            );
983            u32s.chunks(ncb).map(|c| c.to_vec()).collect()
984        } else {
985            let codes = speech_enc.encode(&ref_pcm)?;
986            info!(
987                "Step 3 (speech tokenizer): {:.1}ms",
988                t1.elapsed().as_secs_f64() * 1000.0
989            );
990            codes
991        };
992        let ref_frames = ref_codes.len();
993        info!(
994            "TTS voice clone: ref_frames={}, spk_embed loaded",
995            ref_frames
996        );
997        // Debug: dump first 5 codec frames for comparison with Python
998        for i in 0..ref_frames.min(5) {
999            info!("  rust codec frame {}: {:?}", i, &ref_codes[i]);
1000        }
1001
1002        info!(
1003            "Step 3 (speech tokenizer): {:.1}ms",
1004            t1.elapsed().as_secs_f64() * 1000.0
1005        );
1006        let t2 = std::time::Instant::now();
1007        // Step 4: Tokenize target text with chat template
1008        let chat_text = format!("<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n");
1009        let encoding = self
1010            .text_tokenizer
1011            .encode(chat_text.as_str(), false)
1012            .map_err(|e| FerrumError::model(format!("tokenize: {e}")))?;
1013        let input_ids: Vec<u32> = encoding.get_ids().to_vec();
1014        // role = input_ids[..3], text_content = input_ids[3..input_ids.len()-5]
1015        let role_ids = &input_ids[..3];
1016        let text_content_ids = &input_ids[3..input_ids.len().saturating_sub(5)];
1017
1018        // Tokenize ref_text
1019        let ref_chat_text = format!("<|im_start|>assistant\n{ref_text}<|im_end|>\n");
1020        let ref_encoding = self
1021            .text_tokenizer
1022            .encode(ref_chat_text.as_str(), false)
1023            .map_err(|e| FerrumError::model(format!("tokenize ref: {e}")))?;
1024        let ref_ids: Vec<u32> = ref_encoding.get_ids().to_vec();
1025        // ref text content: ref_ids[3..ref_ids.len()-2]
1026        let ref_text_ids = &ref_ids[3..ref_ids.len().saturating_sub(2)];
1027
1028        // Step 5: Build prefill prompt (dual-stream text+codec summing)
1029        self.talker.reset();
1030
1031        let tts_bos = self.config.tts_bos_token_id;
1032        let tts_eos = self.config.tts_eos_token_id;
1033        let tts_pad = self.config.tts_pad_token_id;
1034        let codec_bos = self.config.codec_bos_id;
1035        let codec_eos = self.config.codec_eos_token_id;
1036        let codec_pad = self.config.codec_pad_id;
1037
1038        // Helper: embed codec and text tokens
1039        let embed_codec_ids = |ids: &[u32]| -> Result<Tensor> {
1040            let t = Tensor::new(ids, &device)
1041                .map_err(|e| FerrumError::model(format!("codec tensor: {e}")))?
1042                .unsqueeze(0)
1043                .map_err(|e| FerrumError::model(format!("codec unsqueeze: {e}")))?;
1044            self.talker.embed_codec(&t)
1045        };
1046        let embed_text_ids = |ids: &[u32]| -> Result<Tensor> {
1047            let t = Tensor::new(ids, &device)
1048                .map_err(|e| FerrumError::model(format!("text tensor: {e}")))?
1049                .unsqueeze(0)
1050                .map_err(|e| FerrumError::model(format!("text unsqueeze: {e}")))?;
1051            self.talker.embed_text(&t)
1052        };
1053
1054        // Get tts special embeddings
1055        let tts_special = embed_text_ids(&[tts_bos, tts_eos, tts_pad])?;
1056        let tts_bos_embed = tts_special
1057            .narrow(1, 0, 1)
1058            .map_err(|e| FerrumError::model(format!("tts_bos narrow: {e}")))?;
1059        let tts_eos_embed = tts_special
1060            .narrow(1, 1, 1)
1061            .map_err(|e| FerrumError::model(format!("tts_eos narrow: {e}")))?;
1062        let tts_pad_embed = tts_special
1063            .narrow(1, 2, 1)
1064            .map_err(|e| FerrumError::model(format!("tts_pad narrow: {e}")))?;
1065
1066        // Resolve language_id — "auto" defaults to "chinese"
1067        let resolved_lang = if language.eq_ignore_ascii_case("auto") {
1068            "chinese"
1069        } else {
1070            language
1071        };
1072        let language_id = self
1073            .config
1074            .codec_language_id
1075            .get(&resolved_lang.to_lowercase());
1076
1077        // Codec prefix: [think, think_bos, lang, think_eos] or [nothink, think_bos, think_eos]
1078        let codec_prefix_ids = if let Some(&lang_id) = language_id {
1079            vec![
1080                self.config.codec_think_id,
1081                self.config.codec_think_bos_id,
1082                lang_id,
1083                self.config.codec_think_eos_id,
1084            ]
1085        } else {
1086            vec![
1087                self.config.codec_nothink_id,
1088                self.config.codec_think_bos_id,
1089                self.config.codec_think_eos_id,
1090            ]
1091        };
1092        let codec_prefix_embed = embed_codec_ids(&codec_prefix_ids)?;
1093
1094        // Codec suffix: [pad, bos]
1095        let codec_suffix_embed = embed_codec_ids(&[codec_pad, codec_bos])?;
1096
1097        // Speaker embed inserted between prefix and suffix
1098        let codec_input = Tensor::cat(&[&codec_prefix_embed, &spk_embed, &codec_suffix_embed], 1)
1099            .map_err(|e| FerrumError::model(format!("codec_input cat: {e}")))?;
1100        let codec_len = codec_input
1101            .dim(1)
1102            .map_err(|e| FerrumError::model(format!("codec_len dim: {e}")))?;
1103
1104        // Role embedding
1105        let role_embed = embed_text_ids(role_ids)?;
1106
1107        // Text-codec prefix: (codec_len - 2) pads + tts_bos, summed with codec[:-1]
1108        let n_pads = codec_len - 2;
1109        let mut text_prefix_parts = Vec::new();
1110        for _ in 0..n_pads {
1111            text_prefix_parts.push(tts_pad_embed.clone());
1112        }
1113        text_prefix_parts.push(tts_bos_embed.clone());
1114        let text_prefix_refs: Vec<&Tensor> = text_prefix_parts.iter().collect();
1115        let text_prefix = Tensor::cat(&text_prefix_refs, 1)
1116            .map_err(|e| FerrumError::model(format!("text_prefix cat: {e}")))?;
1117        let codec_prefix_part = codec_input
1118            .narrow(1, 0, codec_len - 1)
1119            .map_err(|e| FerrumError::model(format!("codec prefix narrow: {e}")))?;
1120        let text_codec_prefix = (&text_prefix + &codec_prefix_part)
1121            .map_err(|e| FerrumError::model(format!("text+codec prefix sum: {e}")))?;
1122
1123        // ICL mode: prefill is 9 positions (no first_text+codec_bos)
1124        let prefill_embed = Tensor::cat(&[&role_embed, &text_codec_prefix], 1)
1125            .map_err(|e| FerrumError::model(format!("prefill cat: {e}")))?;
1126
1127        let t3 = std::time::Instant::now();
1128
1129        // Step 6b: Build ICL block — [ref_text, target_text, tts_eos] + [codec_bos, ref_codec]
1130        // Streaming mode: element-wise sum, trailing = text beyond codec length
1131        let all_text_ids: Vec<u32> = ref_text_ids
1132            .iter()
1133            .chain(text_content_ids.iter())
1134            .copied()
1135            .collect();
1136        let text_embed = embed_text_ids(&all_text_ids)?;
1137        let text_embed_with_eos = Tensor::cat(&[&text_embed, &tts_eos_embed], 1)
1138            .map_err(|e| FerrumError::model(format!("text+eos cat: {e}")))?;
1139        let text_len = text_embed_with_eos
1140            .dim(1)
1141            .map_err(|e| FerrumError::model(format!("text_len dim: {e}")))?;
1142
1143        // Codec stream: batch all codebook embeddings in one shot
1144        // Collect all first-codebook IDs → single batch embed
1145        let t_codec_start = std::time::Instant::now();
1146        let ncg = self.config.num_code_groups;
1147        let first_codes: Vec<u32> = ref_codes.iter().map(|f| f[0]).collect();
1148        let codec_frames_cat = {
1149            // Batch embed main codec: [1, nframes, hidden]
1150            let mut sum = embed_codec_ids(&first_codes)?;
1151            // Batch embed each sub-codebook and accumulate
1152            for cb in 0..(ncg - 1) {
1153                let codes: Vec<u32> = ref_codes.iter().map(|f| f[cb + 1]).collect();
1154                let codes_t = Tensor::new(codes.as_slice(), &device)
1155                    .map_err(|e| FerrumError::model(format!("batch codes: {e}")))?
1156                    .unsqueeze(0)
1157                    .map_err(|e| FerrumError::model(format!("batch unsqueeze: {e}")))?;
1158                let sub_embed = codes_t
1159                    .apply(&self.sub_talker.codec_embeddings[cb])
1160                    .map_err(|e| FerrumError::model(format!("batch sub_embed: {e}")))?;
1161                sum =
1162                    (sum + sub_embed).map_err(|e| FerrumError::model(format!("batch add: {e}")))?;
1163            }
1164            sum
1165        };
1166        info!(
1167            "Codec embedding: {:.1}ms ({} frames × {} codebooks)",
1168            t_codec_start.elapsed().as_secs_f64() * 1000.0,
1169            ref_codes.len(),
1170            ncg
1171        );
1172
1173        // Prepend codec_bos to codec frames: [codec_bos_embed, codec_frames]
1174        let t_merge = std::time::Instant::now();
1175        let codec_bos_for_icl = embed_codec_ids(&[codec_bos])?;
1176        let icl_codec = Tensor::cat(&[&codec_bos_for_icl, &codec_frames_cat], 1)
1177            .map_err(|e| FerrumError::model(format!("icl_codec cat: {e}")))?;
1178        let codec_icl_len = icl_codec
1179            .dim(1)
1180            .map_err(|e| FerrumError::model(format!("codec_icl_len dim: {e}")))?;
1181
1182        // Build ICL embed: element-wise sum of text and codec (streaming mode)
1183        let icl_trailing: Tensor;
1184        let icl_embed: Tensor;
1185        if text_len > codec_icl_len {
1186            let text_part = text_embed_with_eos
1187                .narrow(1, 0, codec_icl_len)
1188                .map_err(|e| FerrumError::model(format!("text_part narrow: {e}")))?;
1189            icl_embed = (&text_part + &icl_codec)
1190                .map_err(|e| FerrumError::model(format!("text+codec sum: {e}")))?;
1191            icl_trailing = text_embed_with_eos
1192                .narrow(1, codec_icl_len, text_len - codec_icl_len)
1193                .map_err(|e| FerrumError::model(format!("trailing narrow: {e}")))?;
1194        } else {
1195            let n_pad = codec_icl_len - text_len;
1196            let text_padded = if n_pad > 0 {
1197                let pad_block = tts_pad_embed
1198                    .expand((1, n_pad, self.config.hidden_size))
1199                    .map_err(|e| FerrumError::model(format!("pad expand: {e}")))?;
1200                Tensor::cat(&[&text_embed_with_eos, &pad_block], 1)
1201                    .map_err(|e| FerrumError::model(format!("text_padded cat: {e}")))?
1202            } else {
1203                text_embed_with_eos.clone()
1204            };
1205            icl_embed = (&text_padded + &icl_codec)
1206                .map_err(|e| FerrumError::model(format!("padded+codec sum: {e}")))?;
1207            icl_trailing = tts_pad_embed.clone();
1208        }
1209        let trailing_text_len = icl_trailing
1210            .dim(1)
1211            .map_err(|e| FerrumError::model(format!("trailing dim: {e}")))?;
1212
1213        // Debug: dump values for comparison with reference project
1214        // Step 6c: Run prefill then ICL block as SEPARATE forward passes
1215        let _prefill_out = self.talker.forward_step(&prefill_embed)?;
1216        let t_icl = std::time::Instant::now();
1217        let icl_hidden = self.talker.forward_step(&icl_embed)?;
1218        let icl_len = icl_hidden
1219            .dim(1)
1220            .map_err(|e| FerrumError::model(format!("icl_hidden dim: {e}")))?;
1221        info!(
1222            "ICL block: {:.1}ms ({} tokens), trailing={}",
1223            t_icl.elapsed().as_secs_f64() * 1000.0,
1224            icl_len,
1225            trailing_text_len
1226        );
1227
1228        // Use ICL hidden output for logits and decode
1229        let mut hidden = icl_hidden;
1230        let hidden_len = hidden
1231            .dim(1)
1232            .map_err(|e| FerrumError::model(format!("hidden dim: {e}")))?;
1233        let last_hidden = hidden
1234            .narrow(1, hidden_len - 1, 1)
1235            .map_err(|e| FerrumError::model(format!("narrow last: {e}")))?;
1236        if let Ok(v) = last_hidden.flatten_all().and_then(|t| t.to_vec1::<f32>()) {}
1237        let current_logits = self.talker.logits(&last_hidden)?;
1238        {}
1239
1240        // Decode loop
1241        let mut all_codec_tokens: Vec<Vec<u32>> = Vec::new();
1242        let mut current_logits = current_logits;
1243
1244        // Suppress special tokens [vocab_size-1024, vocab_size) except EOS
1245        let suppress_start = self.config.vocab_size.saturating_sub(1024);
1246        let suppress_end = self.config.vocab_size;
1247
1248        // ICL mode: stronger repetition penalty (matching reference Rust project)
1249        // + repetition detection for early stop.
1250        const ICL_REPETITION_PENALTY: f32 = 1.5;
1251        const ICL_FRAMES_PER_TOKEN: usize = 6;
1252        const ICL_MIN_FRAMES: usize = 75;
1253        let max_icl_tokens = ICL_MIN_FRAMES.max(text_content_ids.len() * ICL_FRAMES_PER_TOKEN);
1254        let mut generated_tokens: Vec<u32> = Vec::new();
1255
1256        for step in 0..max_icl_tokens {
1257            let mut logits_vec = logits_to_vec(&current_logits)?;
1258            // Suppress special tokens [vocab-1024, vocab) except EOS
1259            for i in suppress_start..suppress_end.min(logits_vec.len()) {
1260                if i as u32 != codec_eos {
1261                    logits_vec[i] = f32::NEG_INFINITY;
1262                }
1263            }
1264            // min_new_tokens: suppress EOS until we've generated a minimum
1265            // number of frames. Reference Python uses sentence-length heuristic.
1266            // FERRUM_TTS_MIN_FRAMES env lets us tune without rebuilding.
1267            let min_frames: usize = std::env::var("FERRUM_TTS_MIN_FRAMES")
1268                .ok()
1269                .and_then(|v| v.parse().ok())
1270                .unwrap_or_else(|| text_content_ids.len() * ICL_FRAMES_PER_TOKEN);
1271            if step < min_frames {
1272                if let Some(v) = logits_vec.get_mut(codec_eos as usize) {
1273                    *v = f32::NEG_INFINITY;
1274                }
1275            }
1276            // Repetition penalty with token history
1277            for &prev_tok in &generated_tokens {
1278                let idx = prev_tok as usize;
1279                if idx < logits_vec.len() {
1280                    if logits_vec[idx] > 0.0 {
1281                        logits_vec[idx] /= ICL_REPETITION_PENALTY;
1282                    } else {
1283                        logits_vec[idx] *= ICL_REPETITION_PENALTY;
1284                    }
1285                }
1286            }
1287            let next_token = sample_token(
1288                &logits_vec,
1289                tts_temperature(),
1290                TOP_K,
1291                ICL_REPETITION_PENALTY,
1292            );
1293
1294            generated_tokens.push(next_token);
1295
1296            if next_token == codec_eos {
1297                info!("TTS voice clone: codec EOS at step {}", step);
1298                break;
1299            }
1300
1301            // Repetition detection: check for repeating patterns of length 1-4
1302            if generated_tokens.len() >= 6 {
1303                let n = generated_tokens.len();
1304                let mut is_repeat = false;
1305                for pat_len in 1..=4 {
1306                    if n >= pat_len * 3 {
1307                        let a = &generated_tokens[n - pat_len * 3..n - pat_len * 2];
1308                        let b = &generated_tokens[n - pat_len * 2..n - pat_len];
1309                        let c = &generated_tokens[n - pat_len..n];
1310                        if a == b && b == c {
1311                            is_repeat = true;
1312                            break;
1313                        }
1314                    }
1315                }
1316                if is_repeat {
1317                    info!(
1318                        "TTS voice clone: repetition detected at step {}, stopping",
1319                        step
1320                    );
1321                    break;
1322                }
1323            }
1324
1325            let cur_hidden_len = hidden
1326                .dim(1)
1327                .map_err(|e| FerrumError::model(format!("hidden dim: {e}")))?;
1328            let last_hidden = hidden
1329                .narrow(1, cur_hidden_len - 1, 1)
1330                .map_err(|e| FerrumError::model(format!("last_hidden: {e}")))?;
1331
1332            let token_tensor = Tensor::new(&[next_token], &device)
1333                .map_err(|e| FerrumError::model(format!("token tensor: {e}")))?
1334                .unsqueeze(0)
1335                .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?;
1336            let first_codec_embed = self.talker.embed_codec(&token_tensor)?;
1337
1338            let extra_codes = self.sub_talker.predict(
1339                &last_hidden,
1340                &first_codec_embed,
1341                st_temperature(),
1342                TOP_K,
1343            )?;
1344
1345            let mut frame_codes = vec![next_token];
1346            frame_codes.extend_from_slice(&extra_codes);
1347            all_codec_tokens.push(frame_codes);
1348
1349            // Sum all codebook embeddings on GPU
1350            let mut combined_embed = first_codec_embed.clone();
1351            for (i, &code) in extra_codes.iter().enumerate() {
1352                let code_t = Tensor::new(&[code], &device)
1353                    .and_then(|t| t.unsqueeze(0))
1354                    .map_err(|e| FerrumError::model(format!("code_t: {e}")))?;
1355                let sub_embed = code_t
1356                    .apply(&self.sub_talker.codec_embeddings[i])
1357                    .map_err(|e| FerrumError::model(format!("sub_embed: {e}")))?;
1358                combined_embed = (combined_embed + sub_embed)
1359                    .map_err(|e| FerrumError::model(format!("add embed: {e}")))?;
1360            }
1361
1362            // Add trailing text or tts_pad (matching reference streaming mode)
1363            if step < trailing_text_len {
1364                let trail = icl_trailing
1365                    .narrow(1, step, 1)
1366                    .map_err(|e| FerrumError::model(format!("trailing narrow: {e}")))?;
1367                combined_embed = (combined_embed + trail)
1368                    .map_err(|e| FerrumError::model(format!("add trailing: {e}")))?;
1369            } else {
1370                combined_embed = (combined_embed + &tts_pad_embed)
1371                    .map_err(|e| FerrumError::model(format!("add tts_pad: {e}")))?;
1372            }
1373
1374            hidden = self.talker.forward_step(&combined_embed)?;
1375            current_logits = self.talker.logits(&hidden)?;
1376        }
1377
1378        if all_codec_tokens.is_empty() {
1379            return Err(FerrumError::model("no codec tokens generated"));
1380        }
1381        info!(
1382            "TTS voice clone: generated {} codec frames",
1383            all_codec_tokens.len()
1384        );
1385
1386        // Step 7: Prepend ref codes and decode with vocoder
1387        let mut all_codes_with_ref = ref_codes.clone();
1388        all_codes_with_ref.extend_from_slice(&all_codec_tokens);
1389
1390        let num_frames = all_codes_with_ref.len();
1391        let num_groups = self.config.num_code_groups;
1392
1393        // Build codec tensor [1, num_groups, T]
1394        let mut flat_codes: Vec<u32> = vec![0; num_groups * num_frames];
1395        for (t, frame) in all_codes_with_ref.iter().enumerate() {
1396            for (g, &code) in frame.iter().take(num_groups).enumerate() {
1397                flat_codes[g * num_frames + t] = code;
1398            }
1399        }
1400
1401        // Clamp to valid range
1402        let codebook_size = 2048u32;
1403        for code in &mut flat_codes {
1404            if *code >= codebook_size {
1405                *code = 0;
1406            }
1407        }
1408
1409        let codes_tensor = Tensor::new(&flat_codes[..], &device)
1410            .map_err(|e| FerrumError::model(format!("codes tensor: {e}")))?
1411            .reshape((1, num_groups, num_frames))
1412            .map_err(|e| FerrumError::model(format!("reshape codes: {e}")))?;
1413
1414        let waveform = self.vocoder.decode(&codes_tensor)?;
1415
1416        let samples: Vec<f32> = waveform
1417            .squeeze(0)
1418            .map_err(|e| FerrumError::model(format!("squeeze batch: {e}")))?
1419            .squeeze(0)
1420            .map_err(|e| FerrumError::model(format!("squeeze channel: {e}")))?
1421            .to_vec1()
1422            .map_err(|e| FerrumError::model(format!("to_vec1: {e}")))?;
1423
1424        // Trim reference portion
1425        let ref_ratio = ref_frames as f64 / num_frames as f64;
1426        let cut = (ref_ratio * samples.len() as f64) as usize;
1427        let output_samples = samples[cut..].to_vec();
1428
1429        info!(
1430            "TTS voice clone: waveform {} samples ({:.2}s), trimmed ref {} samples",
1431            output_samples.len(),
1432            output_samples.len() as f64 / SAMPLE_RATE as f64,
1433            cut,
1434        );
1435
1436        Ok(output_samples)
1437    }
1438}
1439
1440// ── Utility functions ───────────────────────────────────────────────────
1441
1442/// Find safetensor files matching a prefix in a directory.
1443fn find_safetensor_files(dir: &std::path::Path, prefix: &str) -> Result<Vec<std::path::PathBuf>> {
1444    // Try single file first
1445    let single = dir.join(format!("{prefix}.safetensors"));
1446    if single.exists() {
1447        return Ok(vec![single]);
1448    }
1449
1450    // Try sharded: prefix-00001-of-00005.safetensors, etc.
1451    let mut files: Vec<std::path::PathBuf> = Vec::new();
1452    if let Ok(entries) = std::fs::read_dir(dir) {
1453        for entry in entries.flatten() {
1454            let path = entry.path();
1455            if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
1456                if name.starts_with(prefix)
1457                    && name.ends_with(".safetensors")
1458                    && name != format!("{prefix}.safetensors")
1459                {
1460                    files.push(path);
1461                }
1462            }
1463        }
1464    }
1465    files.sort();
1466
1467    if files.is_empty() {
1468        Err(FerrumError::model(format!(
1469            "no safetensors files with prefix '{prefix}' in {}",
1470            dir.display()
1471        )))
1472    } else {
1473        Ok(files)
1474    }
1475}
1476
1477/// Load a BPE tokenizer from vocab.json + merges.txt.
1478fn load_bpe_tokenizer(dir: &std::path::Path) -> Result<tokenizers::Tokenizer> {
1479    // Try tokenizer.json first (HF fast tokenizer format)
1480    let tokenizer_json = dir.join("tokenizer.json");
1481    if tokenizer_json.exists() {
1482        return tokenizers::Tokenizer::from_file(&tokenizer_json)
1483            .map_err(|e| FerrumError::model(format!("load tokenizer.json: {e}")));
1484    }
1485
1486    // Fallback: build from vocab.json + merges.txt
1487    let vocab_path = dir.join("vocab.json");
1488    let merges_path = dir.join("merges.txt");
1489
1490    if !vocab_path.exists() || !merges_path.exists() {
1491        return Err(FerrumError::model(
1492            "tokenizer.json not found, and vocab.json + merges.txt not found either",
1493        ));
1494    }
1495
1496    let vocab_data = std::fs::read_to_string(&vocab_path)
1497        .map_err(|e| FerrumError::model(format!("read vocab.json: {e}")))?;
1498    let vocab: HashMap<String, u32> = serde_json::from_str(&vocab_data)
1499        .map_err(|e| FerrumError::model(format!("parse vocab.json: {e}")))?;
1500
1501    let merges_data = std::fs::read_to_string(&merges_path)
1502        .map_err(|e| FerrumError::model(format!("read merges.txt: {e}")))?;
1503    let merges: Vec<(String, String)> = merges_data
1504        .lines()
1505        .skip(1) // skip header line
1506        .filter(|line| !line.is_empty())
1507        .filter_map(|line| {
1508            let parts: Vec<&str> = line.splitn(2, ' ').collect();
1509            if parts.len() == 2 {
1510                Some((parts[0].to_string(), parts[1].to_string()))
1511            } else {
1512                None
1513            }
1514        })
1515        .collect();
1516
1517    let bpe = tokenizers::models::bpe::BPE::from_file(
1518        vocab_path.to_str().unwrap(),
1519        merges_path.to_str().unwrap(),
1520    )
1521    .build()
1522    .map_err(|e| FerrumError::model(format!("build BPE: {e}")))?;
1523
1524    let tokenizer = tokenizers::Tokenizer::new(bpe);
1525    Ok(tokenizer)
1526}
1527
1528/// Extract logits as Vec<f32> from a [1, 1, vocab] or [1, vocab] tensor.
1529fn logits_to_vec(logits: &Tensor) -> Result<Vec<f32>> {
1530    let logits = if logits.dims().len() == 3 {
1531        logits
1532            .squeeze(0)
1533            .map_err(|e| FerrumError::model(format!("squeeze: {e}")))?
1534            .squeeze(0)
1535            .map_err(|e| FerrumError::model(format!("squeeze: {e}")))?
1536    } else if logits.dims().len() == 2 {
1537        logits
1538            .squeeze(0)
1539            .map_err(|e| FerrumError::model(format!("squeeze: {e}")))?
1540    } else {
1541        logits.clone()
1542    };
1543
1544    logits
1545        .to_vec1()
1546        .map_err(|e| FerrumError::model(format!("logits to_vec1: {e}")))
1547}
1548
1549/// Sample a token from logits with temperature, top-k, and repetition penalty.
1550/// Sample a token matching qwen3-tts-rs reference:
1551/// 1. temperature scaling
1552/// 2. top-k filter (keep top_k, rest = -inf)
1553/// 3. top-p filter (keep smallest set with cumprob > top_p, rest = -inf)
1554/// 4. softmax over filtered logits
1555/// 5. multinomial sample from distribution
1556pub fn sample_token(
1557    logits: &[f32],
1558    temperature: f32,
1559    top_k: usize,
1560    _repetition_penalty: f32,
1561) -> u32 {
1562    if temperature < 0.01 {
1563        return argmax(logits);
1564    }
1565
1566    let vocab = logits.len();
1567
1568    // 1. Apply temperature
1569    let scaled: Vec<f32> = logits.iter().map(|&x| x / temperature).collect();
1570
1571    // 2. Top-k filter: keep top_k values, set rest to -inf
1572    let mut filtered = scaled.clone();
1573    if top_k > 0 && top_k < vocab {
1574        let mut sorted = scaled.clone();
1575        sorted.sort_unstable_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
1576        let threshold = sorted[top_k - 1];
1577        for v in &mut filtered {
1578            if *v < threshold {
1579                *v = f32::NEG_INFINITY;
1580            }
1581        }
1582    }
1583
1584    // 3. Top-p filter: keep smallest set of tokens whose cumulative prob >= top_p
1585    const TOP_P: f32 = 0.9;
1586    {
1587        let mut indices: Vec<usize> = (0..vocab).collect();
1588        indices.sort_unstable_by(|&a, &b| {
1589            filtered[b]
1590                .partial_cmp(&filtered[a])
1591                .unwrap_or(std::cmp::Ordering::Equal)
1592        });
1593
1594        // Softmax over sorted values for cumulative prob
1595        let max_val = filtered[indices[0]];
1596        let exp_sorted: Vec<f32> = indices
1597            .iter()
1598            .map(|&i| (filtered[i] - max_val).exp())
1599            .collect();
1600        let sum: f32 = exp_sorted.iter().sum();
1601        let probs_sorted: Vec<f32> = exp_sorted.iter().map(|e| e / sum).collect();
1602
1603        // Find cutoff
1604        let mut cumsum = 0.0f32;
1605        let mut cutoff_idx = vocab;
1606        for (i, &p) in probs_sorted.iter().enumerate() {
1607            cumsum += p;
1608            if cumsum > TOP_P {
1609                cutoff_idx = i + 1;
1610                break;
1611            }
1612        }
1613
1614        // Mask out tokens beyond cutoff
1615        for &idx in &indices[cutoff_idx..] {
1616            filtered[idx] = f32::NEG_INFINITY;
1617        }
1618    }
1619
1620    // 4. Softmax over filtered logits
1621    let max_val = filtered.iter().copied().fold(f32::NEG_INFINITY, f32::max);
1622    let exps: Vec<f32> = filtered.iter().map(|&v| (v - max_val).exp()).collect();
1623    let sum: f32 = exps.iter().sum();
1624    let probs: Vec<f32> = exps.iter().map(|e| e / sum).collect();
1625
1626    // 5. Multinomial sample
1627    let r = rand_f32();
1628    let mut cumulative = 0.0f32;
1629    for (i, &p) in probs.iter().enumerate() {
1630        cumulative += p;
1631        if cumulative >= r {
1632            return i as u32;
1633        }
1634    }
1635    // Fallback
1636    argmax(&probs)
1637}
1638
/// Index of the largest value in `v`, or 0 when `v` is empty.
///
/// Among equal maxima the last index wins; pairs that cannot be ordered
/// (a NaN is involved) are treated as equal, so the later element is kept.
fn argmax(v: &[f32]) -> u32 {
    let mut best: Option<(usize, f32)> = None;
    for (idx, &val) in v.iter().enumerate() {
        // Keep the current best only when it is strictly greater than `val`.
        let keep_current = match best {
            Some((_, current)) => {
                current
                    .partial_cmp(&val)
                    .unwrap_or(std::cmp::Ordering::Equal)
                    == std::cmp::Ordering::Greater
            }
            None => false,
        };
        if !keep_current {
            best = Some((idx, val));
        }
    }
    best.map_or(0, |(idx, _)| idx as u32)
}
1646
/// Lightweight, non-cryptographic random `f32` in the half-open range `[0, 1)`.
///
/// The input is the current wall-clock sub-second nanoseconds combined with a
/// process-wide call counter; it is scrambled with the SplitMix64 finalizer so
/// that the result is spread over the whole unit interval. (A single LCG step on
/// such a small input never wraps, which confined the previous output to roughly
/// `[0, 0.06]`.)
fn rand_f32() -> f32 {
    use std::sync::atomic::{AtomicU64, Ordering};
    static COUNTER: AtomicU64 = AtomicU64::new(0);

    let seed = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .subsec_nanos() as u64;
    let count = COUNTER.fetch_add(1, Ordering::Relaxed);

    // SplitMix64: golden-ratio increment per call, then the avalanche finalizer.
    let mut z = seed.wrapping_add(count.wrapping_mul(0x9E37_79B9_7F4A_7C15));
    z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    z ^= z >> 31;

    // Use the top 24 bits: every such integer is exactly representable in f32,
    // so the quotient is always strictly below 1.0.
    ((z >> 40) as f32) / ((1u32 << 24) as f32)
}
1664
1665// ── Dummy KV cache + ModelExecutor trait impl ───────────────────────────
1666
1667#[derive(Clone, Debug)]
1668#[allow(dead_code)]
1669struct DummyTtsCache;
1670
1671impl ferrum_interfaces::KvCacheHandle for DummyTtsCache {
1672    fn block_table(&self) -> &ferrum_interfaces::BlockTable {
1673        static EMPTY: std::sync::OnceLock<ferrum_interfaces::BlockTable> =
1674            std::sync::OnceLock::new();
1675        EMPTY.get_or_init(|| ferrum_interfaces::BlockTable::new(16))
1676    }
1677    fn block_table_mut(&mut self) -> &mut ferrum_interfaces::BlockTable {
1678        unimplemented!()
1679    }
1680    fn as_any(&self) -> &dyn std::any::Any {
1681        self
1682    }
1683    fn device(&self) -> Device {
1684        Device::CPU
1685    }
1686    fn num_layers(&self) -> usize {
1687        0
1688    }
1689    fn num_heads(&self) -> usize {
1690        0
1691    }
1692    fn head_dim(&self) -> usize {
1693        0
1694    }
1695    fn key_cache(&self, _: usize) -> Result<Option<TensorRef>> {
1696        Ok(None)
1697    }
1698    fn value_cache(&self, _: usize) -> Result<Option<TensorRef>> {
1699        Ok(None)
1700    }
1701    fn clone_handle(&self) -> Result<Arc<dyn ferrum_interfaces::KvCacheHandle>> {
1702        Ok(Arc::new(self.clone()))
1703    }
1704    fn stats(&self) -> ferrum_interfaces::CacheHandleStats {
1705        ferrum_interfaces::CacheHandleStats {
1706            memory_bytes: 0,
1707            blocks_allocated: 0,
1708            tokens_stored: 0,
1709            utilization: 0.0,
1710            last_access: std::time::Instant::now(),
1711        }
1712    }
1713    fn is_valid(&self) -> bool {
1714        true
1715    }
1716    fn cache_id(&self) -> String {
1717        "tts_dummy".to_string()
1718    }
1719}
1720
1721#[async_trait]
1722impl ModelExecutor for TtsModelExecutor {
1723    fn info(&self) -> &ModelInfo {
1724        &self.info
1725    }
1726
1727    async fn prefill(&self, _input: &PrefillInput) -> Result<PrefillOutput> {
1728        Err(FerrumError::model(
1729            "TTS uses synthesize(), not prefill/decode",
1730        ))
1731    }
1732
1733    async fn decode(&self, _input: &DecodeInput) -> Result<DecodeOutput> {
1734        Err(FerrumError::model(
1735            "TTS uses synthesize(), not prefill/decode",
1736        ))
1737    }
1738
1739    fn capabilities(&self) -> ExecutorCapabilities {
1740        ExecutorCapabilities {
1741            max_batch_size: 1,
1742            max_sequence_length: self.info.max_sequence_length,
1743            attention_mechanisms: vec![AttentionType::GroupedQuery],
1744            supports_dynamic_batching: false,
1745            supports_continuous_batching: false,
1746            supports_speculative_decoding: false,
1747            supports_tensor_parallelism: false,
1748            supports_pipeline_parallelism: false,
1749            supported_dtypes: vec![DataType::FP32, DataType::BF16],
1750            supported_devices: vec![self.info.device.clone()],
1751            memory_requirements: MemoryRequirements {
1752                parameter_memory: 0,
1753                activation_memory_per_token: 0,
1754                kv_cache_memory_per_token: 0,
1755                overhead_memory: 0,
1756            },
1757        }
1758    }
1759
1760    fn release_cache(&self, _: &str) {}
1761
1762    fn status(&self) -> ferrum_interfaces::model_executor::ExecutorStatus {
1763        common::default_executor_status()
1764    }
1765}