Skip to main content

ferrum_models/executor/
tts_executor.rs

1//! Qwen3-TTS Executor — text-to-speech pipeline wiring Talker LM + Vocoder.
2//!
3//! Implements: text tokenization, autoregressive codec token generation,
4//! SubTalker code prediction (TODO), vocoder waveform synthesis.
5
6#![allow(dead_code, unused_imports, unused_variables, unused_mut, unused_parens)]
7
8use std::collections::HashMap;
9use std::sync::Arc;
10
11use async_trait::async_trait;
12use candle_core::{DType, Device as CandleDevice, Tensor};
13use candle_nn::VarBuilder;
14use ferrum_interfaces::{
15    model_executor::{
16        AttentionType, DecodeInput, DecodeOutput, ExecutorCapabilities, MemoryRequirements,
17        PrefillInput, PrefillOutput,
18    },
19    ModelExecutor, TensorRef,
20};
21use ferrum_types::{DataType, Device, FerrumError, ModelInfo, ModelType, Result};
22use tracing::info;
23
24use super::common;
25use crate::architectures::qwen3_tts::{Qwen3TTSTalker, SubTalker, TalkerConfig};
26use crate::architectures::qwen3_tts_backbone::TalkerBackboneBackend;
27use crate::architectures::qwen3_tts_vocoder::{Qwen3TTSVocoder, VocoderConfig};
28use crate::architectures::speaker_encoder::{mel_spectrogram_speaker_encoder, SpeakerEncoder};
29use crate::architectures::speech_tokenizer_encoder::SpeechTokenizerEncoder;
30use ferrum_quantization::NativeSafetensorsLoader;
31
/// Swap the transformer stacks of the Talker and SubTalker for
/// `TalkerBackboneBackend<CudaBackend>` implementations, so they execute via
/// ferrum-kernels cuBLAS + CUDA kernels instead of the broken
/// fused-on-Linux CPU fallback.
///
/// The Talker override is installed before the SubTalker backbone is built,
/// so a failure while building the latter leaves the Talker override in place.
///
/// # Errors
/// Propagates any failure from opening the safetensors loader on `model_dir`
/// or from constructing either backbone.
#[cfg(feature = "cuda")]
fn install_cuda_backend_overrides(
    cfg: &TalkerConfig,
    model_dir: &std::path::Path,
    talker: &mut Qwen3TTSTalker,
    sub_talker: &mut SubTalker,
) -> Result<()> {
    use ferrum_kernels::backend::cuda::CudaBackend;

    // One loader over the model directory feeds both backbones.
    let weights = NativeSafetensorsLoader::<CudaBackend>::open(model_dir)?;

    // Main Talker backbone.
    let talker_backbone = TalkerBackboneBackend::<CudaBackend>::new(cfg, &weights)?;
    talker.set_backend_override(Box::new(talker_backbone));

    // SubTalker (code predictor) backbone.
    let predictor_backbone =
        TalkerBackboneBackend::<CudaBackend>::new_code_predictor(cfg, &weights)?;
    sub_talker.set_backend_override(Box::new(predictor_backbone));

    Ok(())
}
50
51// ── Constants ────────────────────────────────────────────────────────────
52
/// Output sample rate in Hz; used when reporting the duration of the
/// synthesized waveform.
const SAMPLE_RATE: usize = 24000;
/// Upper bound on the number of autoregressive decode steps (codec frames)
/// attempted before generation stops without having seen the codec EOS token.
const MAX_CODEC_TOKENS: usize = 2000;

/// Sampling parameters for codec token generation.
///
/// Default sampling temperature. The `FERRUM_TTS_TEMP` env var overrides it
/// (0.0 = greedy, 0.9 = default sampling).
const TEMPERATURE: f32 = 0.9;
/// Top-k value handed to the token sampler and to the SubTalker predictor.
const TOP_K: usize = 50;
/// Penalty applied to the logits of codec tokens that were already generated.
const REPETITION_PENALTY: f32 = 1.05;
61
62fn tts_temperature() -> f32 {
63    std::env::var("FERRUM_TTS_TEMP")
64        .ok()
65        .and_then(|v| v.parse().ok())
66        .unwrap_or(TEMPERATURE)
67}
68
69fn st_temperature() -> f32 {
70    std::env::var("FERRUM_ST_TEMP")
71        .ok()
72        .and_then(|v| v.parse().ok())
73        .unwrap_or(tts_temperature())
74}
75
/// Qwen3-TTS executor: text-to-speech synthesis.
///
/// Bundles the models and tokenizer needed to turn text into a waveform:
/// the Talker generates the first codec token of each frame, the SubTalker
/// predicts the remaining codes of that frame, and the Vocoder decodes the
/// codec frames into audio samples.
pub struct TtsModelExecutor {
    /// Main autoregressive Talker LM (text/codec embeddings, transformer, codec logits).
    talker: Qwen3TTSTalker,
    /// Code predictor that fills in the additional codec codes of each frame.
    sub_talker: SubTalker,
    /// Vocoder that decodes codec codes into a waveform.
    vocoder: Qwen3TTSVocoder,
    /// Text tokenizer used to encode the input text.
    text_tokenizer: tokenizers::Tokenizer,
    /// Talker configuration parsed from `config.json`.
    config: TalkerConfig,
    /// Static model description built at load time.
    info: ModelInfo,
    /// Optional speaker encoder (voice cloning); `None` when it could not be loaded.
    speaker_encoder: Option<SpeakerEncoder>,
    /// Optional speech-tokenizer encoder, loaded on CPU; `None` when
    /// `speech_tokenizer/config.json` is absent or loading failed.
    speech_tokenizer_encoder: Option<SpeechTokenizerEncoder>,
}
87
88impl TtsModelExecutor {
89    /// Load from model directory containing:
90    /// - config.json (TalkerConfig)
91    /// - model.safetensors (Talker weights)
92    /// - speech_tokenizer/model.safetensors (Vocoder weights)
93    /// - tokenizer_config.json + vocab.json + merges.txt (text tokenizer)
94    pub fn from_path(model_path: &str, device: CandleDevice, dtype: DType) -> Result<Self> {
95        let dir = std::path::Path::new(model_path);
96
97        // Parse TalkerConfig from config.json
98        let config_json: serde_json::Value = {
99            let config_path = dir.join("config.json");
100            let data = std::fs::read_to_string(&config_path)
101                .map_err(|e| FerrumError::model(format!("read config.json: {e}")))?;
102            serde_json::from_str(&data)
103                .map_err(|e| FerrumError::model(format!("parse config.json: {e}")))?
104        };
105        let config = TalkerConfig::from_json(&config_json)?;
106
107        // Load text tokenizer from vocab.json + merges.txt
108        let text_tokenizer = load_bpe_tokenizer(dir)?;
109
110        // Load Talker weights from model.safetensors (or sharded)
111        let talker_weights = find_safetensor_files(dir, "model")?;
112        let talker_vb = unsafe {
113            VarBuilder::from_mmaped_safetensors(&talker_weights, dtype, &device)
114                .map_err(|e| FerrumError::model(format!("load talker weights: {e}")))?
115        };
116        let mut talker = Qwen3TTSTalker::load(&config, talker_vb.clone(), device.clone())?;
117
118        // Load SubTalker (code predictor) from same weights file
119        let mut sub_talker = SubTalker::load(&config, talker_vb.clone(), device.clone())?;
120
121        // Load Speaker Encoder (for voice cloning, base models only)
122        let spk_enc_dim = config_json
123            .get("speaker_encoder_config")
124            .and_then(|c| c.get("enc_dim"))
125            .and_then(|v| v.as_u64())
126            .unwrap_or(1024) as usize;
127        let speaker_encoder =
128            SpeakerEncoder::load_with_dim(talker_vb.pp("speaker_encoder"), spk_enc_dim)
129                .map_err(|e| {
130                    tracing::warn!("Speaker encoder not available: {e}");
131                    e
132                })
133                .ok();
134
135        // Load Vocoder weights from speech_tokenizer/model.safetensors
136        let vocoder_dir = dir.join("speech_tokenizer");
137        let vocoder_weights = find_safetensor_files(&vocoder_dir, "model")?;
138        let vocoder_vb = unsafe {
139            VarBuilder::from_mmaped_safetensors(&vocoder_weights, dtype, &device)
140                .map_err(|e| FerrumError::model(format!("load vocoder weights: {e}")))?
141        };
142        let vocoder_config = VocoderConfig::default();
143        let vocoder = Qwen3TTSVocoder::load(&vocoder_config, vocoder_vb.clone())?;
144
145        // Load Speech Tokenizer Encoder on CPU — Metal float32 accumulation order
146        // causes transformer output divergence that amplifies through RVQ codebook lookup.
147        // CPU is exact and encoder only runs once per reference audio.
148        let speech_tokenizer_encoder = if vocoder_dir.join("config.json").exists() {
149            let cpu_vb = unsafe {
150                VarBuilder::from_mmaped_safetensors(&vocoder_weights, dtype, &CandleDevice::Cpu)
151                    .map_err(|e| FerrumError::model(format!("load encoder cpu: {e}")))?
152            };
153            SpeechTokenizerEncoder::load(cpu_vb.pp("encoder"), CandleDevice::Cpu)
154                .map_err(|e| {
155                    tracing::warn!("Speech tokenizer encoder not available: {e}");
156                    e
157                })
158                .ok()
159        } else {
160            None
161        };
162
163        // Install a Backend<B>-backed transformer for the Talker/SubTalker
164        // when running on CUDA. The legacy `ferrum_attention::FusedTransformer`
165        // CUDA module is a stub and its Linux CPU fallback uses naive fp64
166        // matmul — the CUDA-only voice-clone regression traces back to that
167        // path accumulating precision drift through the 20-layer decoder.
168        // Routing through LlamaFamilyModel<CudaBackend> gives us cuBLAS +
169        // ferrum-kernels for the transformer stack while keeping candle
170        // embeddings / projection / codec_head unchanged.
171        #[cfg(feature = "cuda")]
172        if matches!(&device, CandleDevice::Cuda(_)) {
173            match install_cuda_backend_overrides(&config, dir, &mut talker, &mut sub_talker) {
174                Ok(()) => {
175                    tracing::info!(
176                        "TtsModelExecutor: Backend<CudaBackend> installed for Talker + SubTalker"
177                    );
178                }
179                Err(e) => {
180                    tracing::warn!(
181                        "TtsModelExecutor: Backend<CudaBackend> install failed ({e}); \
182                         falling back to candle/fused path (CUDA voice-clone may produce garbage)"
183                    );
184                }
185            }
186        }
187
188        let info = ModelInfo {
189            model_id: ferrum_types::ModelId(model_path.to_string()),
190            model_type: ModelType::Custom("qwen3-tts".to_string()),
191            hidden_size: config.hidden_size,
192            vocab_size: config.vocab_size,
193            num_layers: config.num_hidden_layers,
194            num_heads: config.num_attention_heads,
195            num_kv_heads: config.num_key_value_heads,
196            num_parameters: 0,
197            max_sequence_length: config.max_position_embeddings,
198            device: match &device {
199                CandleDevice::Cpu => Device::CPU,
200                CandleDevice::Cuda(_) => Device::CUDA(0),
201                #[cfg(any(target_os = "macos", target_os = "ios"))]
202                CandleDevice::Metal(_) => Device::Metal,
203                #[cfg(not(any(target_os = "macos", target_os = "ios")))]
204                CandleDevice::Metal(_) => Device::CPU,
205            },
206            dtype: match dtype {
207                DType::F32 => DataType::FP32,
208                DType::F16 => DataType::FP16,
209                DType::BF16 => DataType::BF16,
210                _ => DataType::FP32,
211            },
212            version: None,
213            license: None,
214            metadata: HashMap::new(),
215        };
216
217        info!(
218            "TtsModelExecutor: {} (hidden={}, layers={}, codec_groups={})",
219            model_path, config.hidden_size, config.num_hidden_layers, config.num_code_groups,
220        );
221
222        Ok(Self {
223            talker,
224            sub_talker,
225            vocoder,
226            text_tokenizer,
227            config,
228            info,
229            speaker_encoder,
230            speech_tokenizer_encoder,
231        })
232    }
233
234    /// Synthesize speech from text.
235    ///
236    /// Returns PCM samples at 24kHz as Vec<f32>.
237    ///
238    /// Prompt structure (matches Python/qwen3-tts-rs):
239    ///   Prefill: [role_prefix(3)] + [tts_text_prefix(6) + codec_prefix(6)] + [first_text + codec_bos]
240    ///   Trailing: text_projection(remaining_text + tts_eos) — added per decode step
241    pub fn synthesize(&mut self, text: &str, language: &str) -> Result<Vec<f32>> {
242        self.talker.reset();
243
244        let device = self.talker.device().clone();
245
246        // 1. Tokenize text (raw content only, no chat template)
247        let encoding = self
248            .text_tokenizer
249            .encode(text, false)
250            .map_err(|e| FerrumError::model(format!("tokenize: {e}")))?;
251        let content_ids: Vec<u32> = encoding.get_ids().to_vec();
252
253        if content_ids.is_empty() {
254            return Err(FerrumError::model("empty text after tokenization"));
255        }
256
257        info!("TTS: content tokens = {}", content_ids.len());
258
259        let codec_eos = self.config.codec_eos_token_id;
260        let tts_pad = self.config.tts_pad_token_id;
261        let tts_bos = self.config.tts_bos_token_id;
262        let tts_eos = self.config.tts_eos_token_id;
263
264        // Helper: embed text token IDs through text_embedding + text_projection
265        let embed_text_ids = |ids: &[u32]| -> Result<Tensor> {
266            let t = Tensor::new(ids, &device)
267                .and_then(|t| t.unsqueeze(0))
268                .map_err(|e| FerrumError::model(format!("text ids: {e}")))?;
269            self.talker.embed_text(&t)
270        };
271        let embed_codec_ids = |ids: &[u32]| -> Result<Tensor> {
272            let t = Tensor::new(ids, &device)
273                .and_then(|t| t.unsqueeze(0))
274                .map_err(|e| FerrumError::model(format!("codec ids: {e}")))?;
275            self.talker.embed_codec(&t)
276        };
277
278        // 2. Build prefill (matching qwen3-tts-rs prefill_custom_voice)
279
280        // Role prefix: text_projection([<|im_start|>, assistant, \n])
281        // These are fixed special token IDs from the tokenizer
282        let im_start_id = 151644u32; // <|im_start|>
283        let assistant_id = 77091u32; // "assistant"
284        let newline_id = 198u32; // "\n"
285        let role_prefix_ids = [im_start_id, assistant_id, newline_id];
286
287        info!(
288            "TTS: role_prefix={} content={} tokens",
289            role_prefix_ids.len(),
290            content_ids.len()
291        );
292
293        // Role prefix embedding (text_projection)
294        let role_embed = embed_text_ids(&role_prefix_ids)?;
295
296        // Codec prefix: [think, think_bos, lang, think_eos, speaker, pad]
297        let resolved_lang = if language == "auto" {
298            "chinese"
299        } else {
300            language
301        };
302        let language_id = self
303            .config
304            .codec_language_id
305            .get(&resolved_lang.to_lowercase());
306        let codec_prefix_ids = if let Some(&lang_id) = language_id {
307            vec![
308                self.config.codec_think_id,
309                self.config.codec_think_bos_id,
310                lang_id,
311                self.config.codec_think_eos_id,
312            ]
313        } else {
314            vec![
315                self.config.codec_nothink_id,
316                self.config.codec_think_bos_id,
317                self.config.codec_think_eos_id,
318            ]
319        };
320        // Codec sequence: [think, think_bos, lang, think_eos, SPEAKER, pad, bos]
321        // Speaker: use Vivian (3065) for Chinese, Ryan (3061) for English
322        let speaker_token = if resolved_lang == "chinese" {
323            3065u32
324        } else {
325            3061u32
326        };
327        let codec_full = {
328            let mut v = codec_prefix_ids.clone();
329            v.push(speaker_token);
330            v.push(self.config.codec_pad_id);
331            v.push(self.config.codec_bos_id);
332            v
333        };
334        let codec_embed = embed_codec_ids(&codec_full)?;
335
336        // tts_text_prefix: [tts_pad * (codec_len-1), tts_bos]
337        let n_codec = codec_full.len();
338        let mut tts_prefix_ids = vec![tts_pad; n_codec - 1];
339        tts_prefix_ids.push(tts_bos);
340        let tts_prefix_embed = embed_text_ids(&tts_prefix_ids)?;
341
342        // Sum: tts_prefix + codec_prefix[:-1]
343        let codec_first = codec_embed
344            .narrow(1, 0, n_codec - 1)
345            .map_err(|e| FerrumError::model(format!("codec narrow: {e}")))?;
346        // Actually we need codec_first to have same length as tts_prefix
347        // tts_prefix has n_codec elements, codec_first has n_codec-1
348        // Let me re-read the reference... codec_embed has 6 tokens [think..pad,bos], tts_prefix has 6 [pad*5,bos]
349        // They sum first 6 codec with first 6 tts_prefix, then codec_bos is separate
350        // codec_full has [think, think_bos, lang, think_eos, speaker, pad, bos] = 7 tokens
351        // tts_prefix: [pad*5, bos] overlaid on codec[0:6] (first 6, excluding last bos)
352        // Then: first_text + codec_bos (last token) summed
353        let n_prefix = n_codec - 1; // 6: everything except codec_bos
354        let codec_prefix_part = codec_embed
355            .narrow(1, 0, n_prefix)
356            .map_err(|e| FerrumError::model(format!("codec narrow: {e}")))?;
357
358        // tts text prefix: [pad * (n_prefix-1), bos]
359        let mut tts_text_prefix_ids = vec![tts_pad; n_prefix - 1];
360        tts_text_prefix_ids.push(tts_bos);
361        let tts_text_embed = embed_text_ids(&tts_text_prefix_ids)?;
362
363        let codec_hidden = (&tts_text_embed + &codec_prefix_part)
364            .map_err(|e| FerrumError::model(format!("prefix sum: {e}")))?;
365
366        // codec_bos is the last element of codec_full
367        let codec_bos_embed = codec_embed
368            .narrow(1, n_prefix, 1)
369            .map_err(|e| FerrumError::model(format!("codec bos: {e}")))?;
370
371        // First text token + codec_bos (summed)
372        let first_text_combined = if !content_ids.is_empty() {
373            let first_text_embed = embed_text_ids(&content_ids[..1])?;
374            (&first_text_embed + &codec_bos_embed)
375                .map_err(|e| FerrumError::model(format!("first text+bos: {e}")))?
376        } else {
377            codec_bos_embed.clone()
378        };
379
380        // Full prefill: [role_prefix, codec_hidden, first_text+codec_bos]
381        let prefill_embeds = Tensor::cat(&[&role_embed, &codec_hidden, &first_text_combined], 1)
382            .map_err(|e| FerrumError::model(format!("prefill cat: {e}")))?;
383
384        let plen = prefill_embeds.dim(1).unwrap_or(0);
385        info!("TTS: prefill_len = {}", plen);
386        // Dump prefill input for comparison with reference
387        if let Ok(v) = prefill_embeds
388            .narrow(0, 0, 1)
389            .and_then(|t| t.narrow(1, 0, 1))
390            .and_then(|t| t.narrow(2, 0, 5))
391            .and_then(|t| t.flatten_all())
392            .and_then(|t| t.to_vec1::<f32>())
393        {
394            info!("  prefill_input pos0[:5] = {:?}", v);
395        }
396        if plen > 0 {
397            if let Ok(v) = prefill_embeds
398                .narrow(0, 0, 1)
399                .and_then(|t| t.narrow(1, plen - 1, 1))
400                .and_then(|t| t.narrow(2, 0, 5))
401                .and_then(|t| t.flatten_all())
402                .and_then(|t| t.to_vec1::<f32>())
403            {
404                info!("  prefill_input pos-1[:5] = {:?}", v);
405            }
406        }
407
408        // 3. Build trailing text: text_projection(remaining_content + tts_eos)
409        let mut trailing_ids: Vec<u32> = if content_ids.len() > 1 {
410            content_ids[1..].to_vec()
411        } else {
412            Vec::new()
413        };
414        trailing_ids.push(tts_eos);
415        let trailing_text_embeds = embed_text_ids(&trailing_ids)?;
416        let trailing_text_len = trailing_text_embeds
417            .dim(1)
418            .map_err(|e| FerrumError::model(format!("trailing dim: {e}")))?;
419        let tts_pad_embed = embed_text_ids(&[tts_pad])?;
420
421        info!("TTS: trailing_text_len = {}", trailing_text_len);
422
423        // 4. Prefill: forward through transformer
424        let mut hidden = self.talker.forward_step(&prefill_embeds)?;
425
426        // Get logits from last position
427        let current_logits = self.talker.logits(
428            &hidden
429                .narrow(1, hidden.dim(1).unwrap() - 1, 1)
430                .map_err(|e| FerrumError::model(format!("narrow: {e}")))?,
431        )?;
432
433        // 4. Autoregressive decode loop: generate codec token 0 per step
434        let mut all_codec_tokens: Vec<Vec<u32>> = Vec::new();
435        let mut current_logits = current_logits;
436
437        // Token suppression: mask [vocab_size-1024, vocab_size) except EOS
438        let suppress_start = self.config.vocab_size.saturating_sub(1024);
439        let suppress_end = self.config.vocab_size;
440        let mut generated_tokens: Vec<u32> = Vec::new();
441
442        for step in 0..MAX_CODEC_TOKENS {
443            // Sample next codec token with suppression + repetition penalty
444            let mut logits_vec = logits_to_vec(&current_logits)?;
445            // Suppress special tokens
446            for i in suppress_start..suppress_end.min(logits_vec.len()) {
447                if i as u32 != codec_eos {
448                    logits_vec[i] = f32::NEG_INFINITY;
449                }
450            }
451            // Repetition penalty (matching Python's repetition_penalty=1.05)
452            for &prev_tok in &generated_tokens {
453                let idx = prev_tok as usize;
454                if idx < logits_vec.len() {
455                    if logits_vec[idx] > 0.0 {
456                        logits_vec[idx] /= REPETITION_PENALTY;
457                    } else {
458                        logits_vec[idx] *= REPETITION_PENALTY;
459                    }
460                }
461            }
462            let next_token =
463                sample_token(&logits_vec, tts_temperature(), TOP_K, REPETITION_PENALTY);
464
465            if step < 10 {
466                // Find argmax (greedy token)
467                let argmax_tok = logits_vec
468                    .iter()
469                    .enumerate()
470                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
471                    .map(|(i, v)| (i, *v))
472                    .unwrap_or((0, 0.0));
473                info!(
474                    "TOKEN step={} sampled={} argmax=({}, {:.2})",
475                    step, next_token, argmax_tok.0, argmax_tok.1
476                );
477            }
478
479            generated_tokens.push(next_token);
480
481            // Check for EOS
482            if next_token == codec_eos {
483                info!("TTS: codec EOS at step {}", step);
484                break;
485            }
486
487            // Get last hidden state from talker for SubTalker
488            let last_hidden = hidden
489                .narrow(1, hidden.dim(1).unwrap() - 1, 1)
490                .map_err(|e| FerrumError::model(format!("last_hidden: {e}")))?;
491
492            // Embed first codec token
493            let token_tensor = Tensor::new(&[next_token], &device)
494                .map_err(|e| FerrumError::model(format!("token tensor: {e}")))?
495                .unsqueeze(0)
496                .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?;
497            let first_codec_embed = self.talker.embed_codec(&token_tensor)?;
498
499            // SubTalker: predict remaining codec tokens 1..num_code_groups-1
500            let st_t0 = std::time::Instant::now();
501            let extra_codes = self.sub_talker.predict(
502                &last_hidden,
503                &first_codec_embed,
504                st_temperature(),
505                TOP_K,
506            )?;
507            if step == 0 {
508                info!(
509                    "  SubTalker: {:.1}ms",
510                    st_t0.elapsed().as_secs_f64() * 1000.0
511                );
512            }
513
514            let mut frame_codes = vec![next_token];
515            frame_codes.extend_from_slice(&extra_codes);
516            all_codec_tokens.push(frame_codes);
517
518            // Build combined embedding for next talker step:
519            // sum of all codec embeddings (token 0 from main talker + tokens 1-15 from sub-talker)
520            let mut combined_embed = first_codec_embed.clone();
521            for (i, &code) in extra_codes.iter().enumerate() {
522                let code_t = Tensor::new(&[code], &device)
523                    .and_then(|t| t.unsqueeze(0))
524                    .map_err(|e| FerrumError::model(format!("code_t: {e}")))?;
525                let sub_embed = code_t
526                    .apply(&self.sub_talker.codec_embeddings[i])
527                    .map_err(|e| FerrumError::model(format!("sub_embed: {e}")))?;
528                combined_embed = (combined_embed + sub_embed)
529                    .map_err(|e| FerrumError::model(format!("add embed: {e}")))?;
530            }
531
532            if step == 0 {
533                if let Ok(v) = combined_embed
534                    .flatten_all()
535                    .and_then(|t| t.narrow(0, 0, 5))
536                    .and_then(|t| t.to_vec1::<f32>())
537                {
538                    info!("STEP0 codec_sum[:5] = {:?} (before trailing)", v);
539                }
540            }
541            // Add trailing text embedding (guides generation toward target text)
542            // Python: inputs_embeds = codec_sum + trailing_text_hidden[:, gen_step]
543            // trailing_text = text_projection(remaining_text) + tts_eos
544            // For basic TTS, trailing covers all text tokens after the first
545            if step < trailing_text_len {
546                let trail = trailing_text_embeds
547                    .narrow(1, step, 1)
548                    .map_err(|e| FerrumError::model(format!("trailing narrow: {e}")))?;
549                combined_embed = (combined_embed + trail)
550                    .map_err(|e| FerrumError::model(format!("add trailing: {e}")))?;
551            } else {
552                combined_embed = (combined_embed + &tts_pad_embed)
553                    .map_err(|e| FerrumError::model(format!("add tts_pad: {e}")))?;
554            }
555
556            // Debug: dump step 0 components
557            if step == 0 {
558                if let Ok(v) = first_codec_embed
559                    .flatten_all()
560                    .and_then(|t| t.narrow(0, 0, 5))
561                    .and_then(|t| t.to_vec1::<f32>())
562                {
563                    info!("STEP0 semantic[:5] = {:?}", v);
564                }
565                if let Ok(v) = combined_embed
566                    .flatten_all()
567                    .and_then(|t| t.narrow(0, 0, 5))
568                    .and_then(|t| t.to_vec1::<f32>())
569                {
570                    info!("STEP0 combined[:5] = {:?}", v);
571                }
572            }
573            // Forward combined embedding through talker
574            let tk_t0 = std::time::Instant::now();
575            hidden = self.talker.forward_step(&combined_embed)?;
576            current_logits = self.talker.logits(&hidden)?;
577            if step == 0 {
578                info!(
579                    "  Talker step: {:.1}ms",
580                    tk_t0.elapsed().as_secs_f64() * 1000.0
581                );
582            }
583        }
584
585        if all_codec_tokens.is_empty() {
586            return Err(FerrumError::model("no codec tokens generated"));
587        }
588
589        info!("TTS: generated {} codec frames", all_codec_tokens.len());
590
591        // 5. Build codec tensor [1, num_code_groups, T] for vocoder
592        let num_frames = all_codec_tokens.len();
593        let num_groups = self.config.num_code_groups;
594        let mut flat_codes: Vec<u32> = vec![0; num_groups * num_frames];
595        for (t, frame) in all_codec_tokens.iter().enumerate() {
596            for (g, &code) in frame.iter().enumerate() {
597                flat_codes[g * num_frames + t] = code;
598            }
599        }
600
601        // Clamp codes to valid codebook range [0, codebook_size-1].
602        // Special tokens (codec_bos, codec_eos, etc.) are >= codebook_size and must be clamped.
603        let codebook_size = 2048u32;
604        for code in &mut flat_codes {
605            if *code >= codebook_size {
606                *code = 0; // replace special tokens with pad/silence
607            }
608        }
609
610        let codes_tensor = Tensor::new(&flat_codes[..], &device)
611            .map_err(|e| FerrumError::model(format!("codes tensor: {e}")))?
612            .reshape((1, num_groups, num_frames))
613            .map_err(|e| FerrumError::model(format!("reshape codes: {e}")))?;
614
615        // 6. Vocoder: codec tokens → waveform
616        let waveform = self.vocoder.decode(&codes_tensor)?;
617
618        // Extract samples: [1, 1, samples] → Vec<f32>
619        let samples: Vec<f32> = waveform
620            .squeeze(0)
621            .map_err(|e| FerrumError::model(format!("squeeze batch: {e}")))?
622            .squeeze(0)
623            .map_err(|e| FerrumError::model(format!("squeeze channel: {e}")))?
624            .to_vec1()
625            .map_err(|e| FerrumError::model(format!("to_vec1: {e}")))?;
626
627        info!(
628            "TTS: waveform {} samples ({:.2}s @ {}Hz)",
629            samples.len(),
630            samples.len() as f64 / SAMPLE_RATE as f64,
631            SAMPLE_RATE,
632        );
633
634        Ok(samples)
635    }
636
637    /// Decode a chunk of codec frames to audio samples.
638    /// frames: Vec<Vec<u32>> where each inner Vec has num_code_groups elements.
639    fn decode_frames(&mut self, frames: &[Vec<u32>], device: &CandleDevice) -> Result<Vec<f32>> {
640        let num_frames = frames.len();
641        if num_frames == 0 {
642            return Ok(vec![]);
643        }
644        let num_groups = self.config.num_code_groups;
645        let codebook_size = 2048u32;
646
647        let mut flat_codes: Vec<u32> = vec![0; num_groups * num_frames];
648        for (t, frame) in frames.iter().enumerate() {
649            for (g, &code) in frame.iter().take(num_groups).enumerate() {
650                flat_codes[g * num_frames + t] = if code >= codebook_size { 0 } else { code };
651            }
652        }
653
654        let codes_tensor = Tensor::new(&flat_codes[..], device)
655            .map_err(|e| FerrumError::model(format!("codes tensor: {e}")))?
656            .reshape((1, num_groups, num_frames))
657            .map_err(|e| FerrumError::model(format!("reshape codes: {e}")))?;
658
659        let waveform = self.vocoder.decode(&codes_tensor)?;
660        waveform
661            .squeeze(0)
662            .and_then(|t| t.squeeze(0))
663            .and_then(|t| t.to_vec1())
664            .map_err(|e| FerrumError::model(format!("waveform extract: {e}")))
665    }
666
667    /// Streaming TTS: calls `on_chunk` with each audio chunk as soon as it's ready.
668    ///
669    /// Each chunk is `chunk_frames` codec frames decoded to audio (~800ms at default 10 frames).
670    /// First chunk arrives after `chunk_frames` decode steps (~2-3s for 0.6B).
671    pub fn synthesize_streaming<F: FnMut(usize, &[f32])>(
672        &mut self,
673        text: &str,
674        language: &str,
675        chunk_frames: usize,
676        mut on_chunk: F,
677    ) -> Result<Vec<Vec<f32>>> {
678        // Reuse existing synthesize setup (prefill + trailing text)
679        // but yield audio in chunks instead of all at once
680        self.talker.reset();
681        let device = self.talker.device().clone();
682
683        let encoding = self
684            .text_tokenizer
685            .encode(text, false)
686            .map_err(|e| FerrumError::model(format!("tokenize: {e}")))?;
687        let content_ids: Vec<u32> = encoding.get_ids().to_vec();
688        if content_ids.is_empty() {
689            return Err(FerrumError::model("empty text after tokenization"));
690        }
691
692        let codec_eos = self.config.codec_eos_token_id;
693        let tts_pad = self.config.tts_pad_token_id;
694        let tts_bos = self.config.tts_bos_token_id;
695        let tts_eos = self.config.tts_eos_token_id;
696
697        // Build embeddings (same as synthesize)
698        let embed_text_ids = |ids: &[u32]| -> Result<Tensor> {
699            let t = Tensor::new(ids, &device)
700                .map_err(|e| FerrumError::model(format!("text tensor: {e}")))?
701                .unsqueeze(0)
702                .map_err(|e| FerrumError::model(format!("text unsqueeze: {e}")))?;
703            self.talker.embed_text(&t)
704        };
705        let embed_codec_ids = |ids: &[u32]| -> Result<Tensor> {
706            let t = Tensor::new(ids, &device)
707                .map_err(|e| FerrumError::model(format!("codec tensor: {e}")))?
708                .unsqueeze(0)
709                .map_err(|e| FerrumError::model(format!("codec unsqueeze: {e}")))?;
710            self.talker.embed_codec(&t)
711        };
712
713        // Codec/text prefix (same as synthesize)
714        let resolved_lang = if language.eq_ignore_ascii_case("auto") {
715            "chinese"
716        } else {
717            language
718        };
719        let language_id = self
720            .config
721            .codec_language_id
722            .get(&resolved_lang.to_lowercase());
723        let codec_prefix_ids = if let Some(&lang_id) = language_id {
724            vec![
725                self.config.codec_think_id,
726                self.config.codec_think_bos_id,
727                lang_id,
728                self.config.codec_think_eos_id,
729            ]
730        } else {
731            vec![
732                self.config.codec_nothink_id,
733                self.config.codec_think_bos_id,
734                self.config.codec_think_eos_id,
735            ]
736        };
737        let speaker_token = if resolved_lang == "chinese" {
738            3065u32
739        } else {
740            3061u32
741        };
742        let mut codec_ids = codec_prefix_ids;
743        codec_ids.push(speaker_token);
744        codec_ids.push(self.config.codec_pad_id);
745        codec_ids.push(self.config.codec_bos_id);
746        let codec_embed = embed_codec_ids(&codec_ids)?;
747        let n_codec = codec_embed
748            .dim(1)
749            .map_err(|e| FerrumError::model(format!("dim: {e}")))?;
750        let n_prefix = n_codec - 1;
751        let codec_prefix_part = codec_embed
752            .narrow(1, 0, n_prefix)
753            .map_err(|e| FerrumError::model(format!("narrow: {e}")))?;
754
755        let mut tts_text_prefix_ids = vec![tts_pad; n_prefix - 1];
756        tts_text_prefix_ids.push(tts_bos);
757        let tts_text_embed = embed_text_ids(&tts_text_prefix_ids)?;
758        let codec_hidden = (&tts_text_embed + &codec_prefix_part)
759            .map_err(|e| FerrumError::model(format!("sum: {e}")))?;
760        let codec_bos_embed = codec_embed
761            .narrow(1, n_prefix, 1)
762            .map_err(|e| FerrumError::model(format!("bos: {e}")))?;
763
764        let role_ids: &[u32] = &[151644, 77091, 198]; // im_start, assistant, \n
765        let role_embed = embed_text_ids(role_ids)?;
766
767        let first_text_combined = if !content_ids.is_empty() {
768            let first_text_embed = embed_text_ids(&content_ids[..1])?;
769            (&first_text_embed + &codec_bos_embed)
770                .map_err(|e| FerrumError::model(format!("first: {e}")))?
771        } else {
772            codec_bos_embed.clone()
773        };
774
775        let prefill_embeds = Tensor::cat(&[&role_embed, &codec_hidden, &first_text_combined], 1)
776            .map_err(|e| FerrumError::model(format!("prefill cat: {e}")))?;
777
778        // Trailing text
779        let trailing_text_embeds = if content_ids.len() > 1 {
780            let remaining = embed_text_ids(&content_ids[1..])?;
781            let eos = embed_text_ids(&[tts_eos])?;
782            Tensor::cat(&[&remaining, &eos], 1)
783                .map_err(|e| FerrumError::model(format!("trailing: {e}")))?
784        } else {
785            embed_text_ids(&[tts_eos])?
786        };
787        let trailing_text_len = trailing_text_embeds
788            .dim(1)
789            .map_err(|e| FerrumError::model(format!("dim: {e}")))?;
790        let tts_pad_embed = embed_text_ids(&[tts_pad])?;
791
792        // Prefill
793        let mut hidden = self.talker.forward_step(&prefill_embeds)?;
794        let mut current_logits = self.talker.logits(
795            &hidden
796                .narrow(1, hidden.dim(1).unwrap() - 1, 1)
797                .map_err(|e| FerrumError::model(format!("narrow: {e}")))?,
798        )?;
799
800        // Streaming decode loop
801        let suppress_start = self.config.vocab_size.saturating_sub(1024);
802        let suppress_end = self.config.vocab_size;
803        let mut generated_tokens: Vec<u32> = Vec::new();
804        let mut frame_buffer: Vec<Vec<u32>> = Vec::new();
805        let mut audio_chunks: Vec<Vec<f32>> = Vec::new();
806
807        for step in 0..MAX_CODEC_TOKENS {
808            let mut logits_vec = logits_to_vec(&current_logits)?;
809            for i in suppress_start..suppress_end.min(logits_vec.len()) {
810                if i as u32 != codec_eos {
811                    logits_vec[i] = f32::NEG_INFINITY;
812                }
813            }
814            for &prev_tok in &generated_tokens {
815                let idx = prev_tok as usize;
816                if idx < logits_vec.len() {
817                    if logits_vec[idx] > 0.0 {
818                        logits_vec[idx] /= REPETITION_PENALTY;
819                    } else {
820                        logits_vec[idx] *= REPETITION_PENALTY;
821                    }
822                }
823            }
824            let next_token =
825                sample_token(&logits_vec, tts_temperature(), TOP_K, REPETITION_PENALTY);
826            generated_tokens.push(next_token);
827
828            if next_token == codec_eos {
829                info!("TTS streaming: EOS at step {}", step);
830                break;
831            }
832
833            let last_hidden = hidden
834                .narrow(1, hidden.dim(1).unwrap() - 1, 1)
835                .map_err(|e| FerrumError::model(format!("narrow: {e}")))?;
836            let token_tensor = Tensor::new(&[next_token], &device)
837                .and_then(|t| t.unsqueeze(0))
838                .map_err(|e| FerrumError::model(format!("tok: {e}")))?;
839            let first_codec_embed = self.talker.embed_codec(&token_tensor)?;
840
841            let extra_codes = self.sub_talker.predict(
842                &last_hidden,
843                &first_codec_embed,
844                st_temperature(),
845                TOP_K,
846            )?;
847
848            let mut frame = vec![next_token];
849            frame.extend_from_slice(&extra_codes);
850            frame_buffer.push(frame);
851
852            // Emit chunk when buffer is full
853            if frame_buffer.len() >= chunk_frames {
854                let chunk_audio = self.decode_frames(&frame_buffer, &device)?;
855                on_chunk(audio_chunks.len(), &chunk_audio);
856                audio_chunks.push(chunk_audio);
857                frame_buffer.clear();
858            }
859
860            // Build combined embed for next step
861            let mut combined_embed = first_codec_embed.clone();
862            for (i, &code) in extra_codes.iter().enumerate() {
863                let code_t = Tensor::new(&[code], &device)
864                    .and_then(|t| t.unsqueeze(0))
865                    .map_err(|e| FerrumError::model(format!("code_t: {e}")))?;
866                let sub_embed = code_t
867                    .apply(&self.sub_talker.codec_embeddings[i])
868                    .map_err(|e| FerrumError::model(format!("sub: {e}")))?;
869                combined_embed = (combined_embed + sub_embed)
870                    .map_err(|e| FerrumError::model(format!("add: {e}")))?;
871            }
872            if step < trailing_text_len {
873                let trail = trailing_text_embeds
874                    .narrow(1, step, 1)
875                    .map_err(|e| FerrumError::model(format!("trail: {e}")))?;
876                combined_embed = (combined_embed + trail)
877                    .map_err(|e| FerrumError::model(format!("add trail: {e}")))?;
878            } else {
879                combined_embed = (combined_embed + &tts_pad_embed)
880                    .map_err(|e| FerrumError::model(format!("pad: {e}")))?;
881            }
882
883            hidden = self.talker.forward_step(&combined_embed)?;
884            current_logits = self.talker.logits(&hidden)?;
885        }
886
887        // Flush remaining frames
888        if !frame_buffer.is_empty() {
889            let chunk_audio = self.decode_frames(&frame_buffer, &device)?;
890            on_chunk(audio_chunks.len(), &chunk_audio);
891            audio_chunks.push(chunk_audio);
892        }
893
894        Ok(audio_chunks)
895    }
896
    /// Returns the output sample rate in samples per second.
    ///
    /// This is the file-level `SAMPLE_RATE` constant; it does not depend on
    /// any executor state.
    pub fn sample_rate(&self) -> usize {
        SAMPLE_RATE
    }
901
    /// Returns a shared reference to the talker configuration held by this executor.
    pub fn config(&self) -> &TalkerConfig {
        &self.config
    }
905
906    /// Synthesize speech with voice cloning from a reference audio.
907    ///
908    /// Uses ICL (in-context learning) prompting: the reference audio is
909    /// encoded to codec tokens and prepended to the generation prompt,
910    /// along with a speaker embedding extracted via ECAPA-TDNN.
911    ///
912    /// Returns PCM samples at 24kHz as Vec<f32>.
913    pub fn synthesize_voice_clone(
914        &mut self,
915        text: &str,
916        language: &str,
917        ref_audio_path: &str,
918        ref_text: &str,
919    ) -> Result<Vec<f32>> {
920        let device = self.talker.device().clone();
921
922        // Step 1: Load and process reference audio at 24kHz
923        let ref_pcm = if let Ok(path) = std::env::var("FERRUM_REF_PCM") {
924            let data = std::fs::read(&path)
925                .map_err(|e| FerrumError::model(format!("read ref pcm: {e}")))?;
926            let pcm: Vec<f32> = data
927                .chunks(4)
928                .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
929                .collect();
930            info!("Loaded ref PCM override: {} samples", pcm.len());
931            pcm
932        } else {
933            crate::audio_processor::load_audio_at_rate(ref_audio_path, 24000)?
934        };
935        info!(
936            "TTS voice clone: loaded ref audio {} samples ({:.2}s)",
937            ref_pcm.len(),
938            ref_pcm.len() as f64 / 24000.0
939        );
940
941        let t0 = std::time::Instant::now();
942        // Step 2: Extract speaker embedding via ECAPA-TDNN
943        let speaker_encoder = self
944            .speaker_encoder
945            .as_ref()
946            .ok_or_else(|| FerrumError::model("speaker encoder not loaded"))?;
947        let mel = mel_spectrogram_speaker_encoder(&ref_pcm);
948        let n_mel_frames = mel.len() / 128;
949        // mel_spectrogram_speaker_encoder returns [T, 128] row-major
950        // forward() internally transposes [1, T, 128] → [1, 128, T]
951        let mel_tensor = Tensor::from_vec(mel, (1, n_mel_frames, 128), &device)
952            .map_err(|e| FerrumError::model(format!("mel tensor: {e}")))?;
953        let spk_embed = speaker_encoder.forward(&mel_tensor)?;
954        info!(
955            "Step 2 (speaker embed): {:.1}ms",
956            t0.elapsed().as_secs_f64() * 1000.0
957        );
958        // spk_embed shape: [enc_dim] -> reshape to [1, 1, hidden_size]
959        let spk_embed = spk_embed
960            .unsqueeze(0)
961            .map_err(|e| FerrumError::model(format!("spk unsqueeze(0): {e}")))?
962            .unsqueeze(0)
963            .map_err(|e| FerrumError::model(format!("spk unsqueeze(0) 2: {e}")))?;
964
965        let t1 = std::time::Instant::now();
966        // Step 3: Encode reference audio to codec tokens (ICL)
967        let speech_enc = self
968            .speech_tokenizer_encoder
969            .as_ref()
970            .ok_or_else(|| FerrumError::model("speech tokenizer encoder not loaded"))?;
971        // Allow pre-computed codec tokens for debugging (FERRUM_REF_CODES=/path/to/codes.bin)
972        let ref_codes = if let Ok(path) = std::env::var("FERRUM_REF_CODES") {
973            let data = std::fs::read(&path)
974                .map_err(|e| FerrumError::model(format!("read ref codes: {e}")))?;
975            let u32s: Vec<u32> = data
976                .chunks(4)
977                .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
978                .collect();
979            let ncb = self.config.num_code_groups;
980            let nframes = u32s.len() / ncb;
981            info!(
982                "Loaded pre-computed ref codes: {} frames from {}",
983                nframes, path
984            );
985            u32s.chunks(ncb).map(|c| c.to_vec()).collect()
986        } else {
987            let codes = speech_enc.encode(&ref_pcm)?;
988            info!(
989                "Step 3 (speech tokenizer): {:.1}ms",
990                t1.elapsed().as_secs_f64() * 1000.0
991            );
992            codes
993        };
994        let ref_frames = ref_codes.len();
995        info!(
996            "TTS voice clone: ref_frames={}, spk_embed loaded",
997            ref_frames
998        );
999        // Debug: dump first 5 codec frames for comparison with Python
1000        for i in 0..ref_frames.min(5) {
1001            info!("  rust codec frame {}: {:?}", i, &ref_codes[i]);
1002        }
1003
1004        info!(
1005            "Step 3 (speech tokenizer): {:.1}ms",
1006            t1.elapsed().as_secs_f64() * 1000.0
1007        );
1008        let t2 = std::time::Instant::now();
1009        // Step 4: Tokenize target text with chat template
1010        let chat_text = format!("<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n");
1011        let encoding = self
1012            .text_tokenizer
1013            .encode(chat_text.as_str(), false)
1014            .map_err(|e| FerrumError::model(format!("tokenize: {e}")))?;
1015        let input_ids: Vec<u32> = encoding.get_ids().to_vec();
1016        // role = input_ids[..3], text_content = input_ids[3..input_ids.len()-5]
1017        let role_ids = &input_ids[..3];
1018        let text_content_ids = &input_ids[3..input_ids.len().saturating_sub(5)];
1019
1020        // Tokenize ref_text
1021        let ref_chat_text = format!("<|im_start|>assistant\n{ref_text}<|im_end|>\n");
1022        let ref_encoding = self
1023            .text_tokenizer
1024            .encode(ref_chat_text.as_str(), false)
1025            .map_err(|e| FerrumError::model(format!("tokenize ref: {e}")))?;
1026        let ref_ids: Vec<u32> = ref_encoding.get_ids().to_vec();
1027        // ref text content: ref_ids[3..ref_ids.len()-2]
1028        let ref_text_ids = &ref_ids[3..ref_ids.len().saturating_sub(2)];
1029
1030        // Step 5: Build prefill prompt (dual-stream text+codec summing)
1031        self.talker.reset();
1032
1033        let tts_bos = self.config.tts_bos_token_id;
1034        let tts_eos = self.config.tts_eos_token_id;
1035        let tts_pad = self.config.tts_pad_token_id;
1036        let codec_bos = self.config.codec_bos_id;
1037        let codec_eos = self.config.codec_eos_token_id;
1038        let codec_pad = self.config.codec_pad_id;
1039
1040        // Helper: embed codec and text tokens
1041        let embed_codec_ids = |ids: &[u32]| -> Result<Tensor> {
1042            let t = Tensor::new(ids, &device)
1043                .map_err(|e| FerrumError::model(format!("codec tensor: {e}")))?
1044                .unsqueeze(0)
1045                .map_err(|e| FerrumError::model(format!("codec unsqueeze: {e}")))?;
1046            self.talker.embed_codec(&t)
1047        };
1048        let embed_text_ids = |ids: &[u32]| -> Result<Tensor> {
1049            let t = Tensor::new(ids, &device)
1050                .map_err(|e| FerrumError::model(format!("text tensor: {e}")))?
1051                .unsqueeze(0)
1052                .map_err(|e| FerrumError::model(format!("text unsqueeze: {e}")))?;
1053            self.talker.embed_text(&t)
1054        };
1055
1056        // Get tts special embeddings
1057        let tts_special = embed_text_ids(&[tts_bos, tts_eos, tts_pad])?;
1058        let tts_bos_embed = tts_special
1059            .narrow(1, 0, 1)
1060            .map_err(|e| FerrumError::model(format!("tts_bos narrow: {e}")))?;
1061        let tts_eos_embed = tts_special
1062            .narrow(1, 1, 1)
1063            .map_err(|e| FerrumError::model(format!("tts_eos narrow: {e}")))?;
1064        let tts_pad_embed = tts_special
1065            .narrow(1, 2, 1)
1066            .map_err(|e| FerrumError::model(format!("tts_pad narrow: {e}")))?;
1067
1068        // Resolve language_id — "auto" defaults to "chinese"
1069        let resolved_lang = if language.eq_ignore_ascii_case("auto") {
1070            "chinese"
1071        } else {
1072            language
1073        };
1074        let language_id = self
1075            .config
1076            .codec_language_id
1077            .get(&resolved_lang.to_lowercase());
1078
1079        // Codec prefix: [think, think_bos, lang, think_eos] or [nothink, think_bos, think_eos]
1080        let codec_prefix_ids = if let Some(&lang_id) = language_id {
1081            vec![
1082                self.config.codec_think_id,
1083                self.config.codec_think_bos_id,
1084                lang_id,
1085                self.config.codec_think_eos_id,
1086            ]
1087        } else {
1088            vec![
1089                self.config.codec_nothink_id,
1090                self.config.codec_think_bos_id,
1091                self.config.codec_think_eos_id,
1092            ]
1093        };
1094        let codec_prefix_embed = embed_codec_ids(&codec_prefix_ids)?;
1095
1096        // Codec suffix: [pad, bos]
1097        let codec_suffix_embed = embed_codec_ids(&[codec_pad, codec_bos])?;
1098
1099        // Speaker embed inserted between prefix and suffix
1100        let codec_input = Tensor::cat(&[&codec_prefix_embed, &spk_embed, &codec_suffix_embed], 1)
1101            .map_err(|e| FerrumError::model(format!("codec_input cat: {e}")))?;
1102        let codec_len = codec_input
1103            .dim(1)
1104            .map_err(|e| FerrumError::model(format!("codec_len dim: {e}")))?;
1105
1106        // Role embedding
1107        let role_embed = embed_text_ids(role_ids)?;
1108
1109        // Text-codec prefix: (codec_len - 2) pads + tts_bos, summed with codec[:-1]
1110        let n_pads = codec_len - 2;
1111        let mut text_prefix_parts = Vec::new();
1112        for _ in 0..n_pads {
1113            text_prefix_parts.push(tts_pad_embed.clone());
1114        }
1115        text_prefix_parts.push(tts_bos_embed.clone());
1116        let text_prefix_refs: Vec<&Tensor> = text_prefix_parts.iter().collect();
1117        let text_prefix = Tensor::cat(&text_prefix_refs, 1)
1118            .map_err(|e| FerrumError::model(format!("text_prefix cat: {e}")))?;
1119        let codec_prefix_part = codec_input
1120            .narrow(1, 0, codec_len - 1)
1121            .map_err(|e| FerrumError::model(format!("codec prefix narrow: {e}")))?;
1122        let text_codec_prefix = (&text_prefix + &codec_prefix_part)
1123            .map_err(|e| FerrumError::model(format!("text+codec prefix sum: {e}")))?;
1124
1125        // ICL mode: prefill is 9 positions (no first_text+codec_bos)
1126        let prefill_embed = Tensor::cat(&[&role_embed, &text_codec_prefix], 1)
1127            .map_err(|e| FerrumError::model(format!("prefill cat: {e}")))?;
1128
1129        let t3 = std::time::Instant::now();
1130
1131        // Step 6b: Build ICL block — [ref_text, target_text, tts_eos] + [codec_bos, ref_codec]
1132        // Streaming mode: element-wise sum, trailing = text beyond codec length
1133        let all_text_ids: Vec<u32> = ref_text_ids
1134            .iter()
1135            .chain(text_content_ids.iter())
1136            .copied()
1137            .collect();
1138        let text_embed = embed_text_ids(&all_text_ids)?;
1139        let text_embed_with_eos = Tensor::cat(&[&text_embed, &tts_eos_embed], 1)
1140            .map_err(|e| FerrumError::model(format!("text+eos cat: {e}")))?;
1141        let text_len = text_embed_with_eos
1142            .dim(1)
1143            .map_err(|e| FerrumError::model(format!("text_len dim: {e}")))?;
1144
1145        // Codec stream: batch all codebook embeddings in one shot
1146        // Collect all first-codebook IDs → single batch embed
1147        let t_codec_start = std::time::Instant::now();
1148        let ncg = self.config.num_code_groups;
1149        let first_codes: Vec<u32> = ref_codes.iter().map(|f| f[0]).collect();
1150        let codec_frames_cat = {
1151            // Batch embed main codec: [1, nframes, hidden]
1152            let mut sum = embed_codec_ids(&first_codes)?;
1153            // Batch embed each sub-codebook and accumulate
1154            for cb in 0..(ncg - 1) {
1155                let codes: Vec<u32> = ref_codes.iter().map(|f| f[cb + 1]).collect();
1156                let codes_t = Tensor::new(codes.as_slice(), &device)
1157                    .map_err(|e| FerrumError::model(format!("batch codes: {e}")))?
1158                    .unsqueeze(0)
1159                    .map_err(|e| FerrumError::model(format!("batch unsqueeze: {e}")))?;
1160                let sub_embed = codes_t
1161                    .apply(&self.sub_talker.codec_embeddings[cb])
1162                    .map_err(|e| FerrumError::model(format!("batch sub_embed: {e}")))?;
1163                sum =
1164                    (sum + sub_embed).map_err(|e| FerrumError::model(format!("batch add: {e}")))?;
1165            }
1166            sum
1167        };
1168        info!(
1169            "Codec embedding: {:.1}ms ({} frames × {} codebooks)",
1170            t_codec_start.elapsed().as_secs_f64() * 1000.0,
1171            ref_codes.len(),
1172            ncg
1173        );
1174
1175        // Prepend codec_bos to codec frames: [codec_bos_embed, codec_frames]
1176        let t_merge = std::time::Instant::now();
1177        let codec_bos_for_icl = embed_codec_ids(&[codec_bos])?;
1178        let icl_codec = Tensor::cat(&[&codec_bos_for_icl, &codec_frames_cat], 1)
1179            .map_err(|e| FerrumError::model(format!("icl_codec cat: {e}")))?;
1180        let codec_icl_len = icl_codec
1181            .dim(1)
1182            .map_err(|e| FerrumError::model(format!("codec_icl_len dim: {e}")))?;
1183
1184        // Build ICL embed: element-wise sum of text and codec (streaming mode)
1185        let icl_trailing: Tensor;
1186        let icl_embed: Tensor;
1187        if text_len > codec_icl_len {
1188            let text_part = text_embed_with_eos
1189                .narrow(1, 0, codec_icl_len)
1190                .map_err(|e| FerrumError::model(format!("text_part narrow: {e}")))?;
1191            icl_embed = (&text_part + &icl_codec)
1192                .map_err(|e| FerrumError::model(format!("text+codec sum: {e}")))?;
1193            icl_trailing = text_embed_with_eos
1194                .narrow(1, codec_icl_len, text_len - codec_icl_len)
1195                .map_err(|e| FerrumError::model(format!("trailing narrow: {e}")))?;
1196        } else {
1197            let n_pad = codec_icl_len - text_len;
1198            let text_padded = if n_pad > 0 {
1199                let pad_block = tts_pad_embed
1200                    .expand((1, n_pad, self.config.hidden_size))
1201                    .map_err(|e| FerrumError::model(format!("pad expand: {e}")))?;
1202                Tensor::cat(&[&text_embed_with_eos, &pad_block], 1)
1203                    .map_err(|e| FerrumError::model(format!("text_padded cat: {e}")))?
1204            } else {
1205                text_embed_with_eos.clone()
1206            };
1207            icl_embed = (&text_padded + &icl_codec)
1208                .map_err(|e| FerrumError::model(format!("padded+codec sum: {e}")))?;
1209            icl_trailing = tts_pad_embed.clone();
1210        }
1211        let trailing_text_len = icl_trailing
1212            .dim(1)
1213            .map_err(|e| FerrumError::model(format!("trailing dim: {e}")))?;
1214
1215        // Debug: dump values for comparison with reference project
1216        // Step 6c: Run prefill then ICL block as SEPARATE forward passes
1217        let _prefill_out = self.talker.forward_step(&prefill_embed)?;
1218        let t_icl = std::time::Instant::now();
1219        let icl_hidden = self.talker.forward_step(&icl_embed)?;
1220        let icl_len = icl_hidden
1221            .dim(1)
1222            .map_err(|e| FerrumError::model(format!("icl_hidden dim: {e}")))?;
1223        info!(
1224            "ICL block: {:.1}ms ({} tokens), trailing={}",
1225            t_icl.elapsed().as_secs_f64() * 1000.0,
1226            icl_len,
1227            trailing_text_len
1228        );
1229
1230        // Use ICL hidden output for logits and decode
1231        let mut hidden = icl_hidden;
1232        let hidden_len = hidden
1233            .dim(1)
1234            .map_err(|e| FerrumError::model(format!("hidden dim: {e}")))?;
1235        let last_hidden = hidden
1236            .narrow(1, hidden_len - 1, 1)
1237            .map_err(|e| FerrumError::model(format!("narrow last: {e}")))?;
1238        if let Ok(v) = last_hidden.flatten_all().and_then(|t| t.to_vec1::<f32>()) {}
1239        let current_logits = self.talker.logits(&last_hidden)?;
1240        {}
1241
1242        // Decode loop
1243        let mut all_codec_tokens: Vec<Vec<u32>> = Vec::new();
1244        let mut current_logits = current_logits;
1245
1246        // Suppress special tokens [vocab_size-1024, vocab_size) except EOS
1247        let suppress_start = self.config.vocab_size.saturating_sub(1024);
1248        let suppress_end = self.config.vocab_size;
1249
1250        // ICL mode: stronger repetition penalty (matching reference Rust project)
1251        // + repetition detection for early stop.
1252        const ICL_REPETITION_PENALTY: f32 = 1.5;
1253        const ICL_FRAMES_PER_TOKEN: usize = 6;
1254        const ICL_MIN_FRAMES: usize = 75;
1255        let max_icl_tokens = ICL_MIN_FRAMES.max(text_content_ids.len() * ICL_FRAMES_PER_TOKEN);
1256        let mut generated_tokens: Vec<u32> = Vec::new();
1257
1258        for step in 0..max_icl_tokens {
1259            let mut logits_vec = logits_to_vec(&current_logits)?;
1260            // Suppress special tokens [vocab-1024, vocab) except EOS
1261            for i in suppress_start..suppress_end.min(logits_vec.len()) {
1262                if i as u32 != codec_eos {
1263                    logits_vec[i] = f32::NEG_INFINITY;
1264                }
1265            }
1266            // min_new_tokens: suppress EOS until we've generated a minimum
1267            // number of frames. Reference Python uses sentence-length heuristic.
1268            // FERRUM_TTS_MIN_FRAMES env lets us tune without rebuilding.
1269            let min_frames: usize = std::env::var("FERRUM_TTS_MIN_FRAMES")
1270                .ok()
1271                .and_then(|v| v.parse().ok())
1272                .unwrap_or_else(|| text_content_ids.len() * ICL_FRAMES_PER_TOKEN);
1273            if step < min_frames {
1274                if let Some(v) = logits_vec.get_mut(codec_eos as usize) {
1275                    *v = f32::NEG_INFINITY;
1276                }
1277            }
1278            // Repetition penalty with token history
1279            for &prev_tok in &generated_tokens {
1280                let idx = prev_tok as usize;
1281                if idx < logits_vec.len() {
1282                    if logits_vec[idx] > 0.0 {
1283                        logits_vec[idx] /= ICL_REPETITION_PENALTY;
1284                    } else {
1285                        logits_vec[idx] *= ICL_REPETITION_PENALTY;
1286                    }
1287                }
1288            }
1289            let next_token = sample_token(
1290                &logits_vec,
1291                tts_temperature(),
1292                TOP_K,
1293                ICL_REPETITION_PENALTY,
1294            );
1295
1296            generated_tokens.push(next_token);
1297
1298            if next_token == codec_eos {
1299                info!("TTS voice clone: codec EOS at step {}", step);
1300                break;
1301            }
1302
1303            // Repetition detection: check for repeating patterns of length 1-4
1304            if generated_tokens.len() >= 6 {
1305                let n = generated_tokens.len();
1306                let mut is_repeat = false;
1307                for pat_len in 1..=4 {
1308                    if n >= pat_len * 3 {
1309                        let a = &generated_tokens[n - pat_len * 3..n - pat_len * 2];
1310                        let b = &generated_tokens[n - pat_len * 2..n - pat_len];
1311                        let c = &generated_tokens[n - pat_len..n];
1312                        if a == b && b == c {
1313                            is_repeat = true;
1314                            break;
1315                        }
1316                    }
1317                }
1318                if is_repeat {
1319                    info!(
1320                        "TTS voice clone: repetition detected at step {}, stopping",
1321                        step
1322                    );
1323                    break;
1324                }
1325            }
1326
1327            let cur_hidden_len = hidden
1328                .dim(1)
1329                .map_err(|e| FerrumError::model(format!("hidden dim: {e}")))?;
1330            let last_hidden = hidden
1331                .narrow(1, cur_hidden_len - 1, 1)
1332                .map_err(|e| FerrumError::model(format!("last_hidden: {e}")))?;
1333
1334            let token_tensor = Tensor::new(&[next_token], &device)
1335                .map_err(|e| FerrumError::model(format!("token tensor: {e}")))?
1336                .unsqueeze(0)
1337                .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?;
1338            let first_codec_embed = self.talker.embed_codec(&token_tensor)?;
1339
1340            let extra_codes = self.sub_talker.predict(
1341                &last_hidden,
1342                &first_codec_embed,
1343                st_temperature(),
1344                TOP_K,
1345            )?;
1346
1347            let mut frame_codes = vec![next_token];
1348            frame_codes.extend_from_slice(&extra_codes);
1349            all_codec_tokens.push(frame_codes);
1350
1351            // Sum all codebook embeddings on GPU
1352            let mut combined_embed = first_codec_embed.clone();
1353            for (i, &code) in extra_codes.iter().enumerate() {
1354                let code_t = Tensor::new(&[code], &device)
1355                    .and_then(|t| t.unsqueeze(0))
1356                    .map_err(|e| FerrumError::model(format!("code_t: {e}")))?;
1357                let sub_embed = code_t
1358                    .apply(&self.sub_talker.codec_embeddings[i])
1359                    .map_err(|e| FerrumError::model(format!("sub_embed: {e}")))?;
1360                combined_embed = (combined_embed + sub_embed)
1361                    .map_err(|e| FerrumError::model(format!("add embed: {e}")))?;
1362            }
1363
1364            // Add trailing text or tts_pad (matching reference streaming mode)
1365            if step < trailing_text_len {
1366                let trail = icl_trailing
1367                    .narrow(1, step, 1)
1368                    .map_err(|e| FerrumError::model(format!("trailing narrow: {e}")))?;
1369                combined_embed = (combined_embed + trail)
1370                    .map_err(|e| FerrumError::model(format!("add trailing: {e}")))?;
1371            } else {
1372                combined_embed = (combined_embed + &tts_pad_embed)
1373                    .map_err(|e| FerrumError::model(format!("add tts_pad: {e}")))?;
1374            }
1375
1376            hidden = self.talker.forward_step(&combined_embed)?;
1377            current_logits = self.talker.logits(&hidden)?;
1378        }
1379
1380        if all_codec_tokens.is_empty() {
1381            return Err(FerrumError::model("no codec tokens generated"));
1382        }
1383        info!(
1384            "TTS voice clone: generated {} codec frames",
1385            all_codec_tokens.len()
1386        );
1387
1388        // Step 7: Prepend ref codes and decode with vocoder
1389        let mut all_codes_with_ref = ref_codes.clone();
1390        all_codes_with_ref.extend_from_slice(&all_codec_tokens);
1391
1392        let num_frames = all_codes_with_ref.len();
1393        let num_groups = self.config.num_code_groups;
1394
1395        // Build codec tensor [1, num_groups, T]
1396        let mut flat_codes: Vec<u32> = vec![0; num_groups * num_frames];
1397        for (t, frame) in all_codes_with_ref.iter().enumerate() {
1398            for (g, &code) in frame.iter().take(num_groups).enumerate() {
1399                flat_codes[g * num_frames + t] = code;
1400            }
1401        }
1402
1403        // Clamp to valid range
1404        let codebook_size = 2048u32;
1405        for code in &mut flat_codes {
1406            if *code >= codebook_size {
1407                *code = 0;
1408            }
1409        }
1410
1411        let codes_tensor = Tensor::new(&flat_codes[..], &device)
1412            .map_err(|e| FerrumError::model(format!("codes tensor: {e}")))?
1413            .reshape((1, num_groups, num_frames))
1414            .map_err(|e| FerrumError::model(format!("reshape codes: {e}")))?;
1415
1416        let waveform = self.vocoder.decode(&codes_tensor)?;
1417
1418        let samples: Vec<f32> = waveform
1419            .squeeze(0)
1420            .map_err(|e| FerrumError::model(format!("squeeze batch: {e}")))?
1421            .squeeze(0)
1422            .map_err(|e| FerrumError::model(format!("squeeze channel: {e}")))?
1423            .to_vec1()
1424            .map_err(|e| FerrumError::model(format!("to_vec1: {e}")))?;
1425
1426        // Trim reference portion
1427        let ref_ratio = ref_frames as f64 / num_frames as f64;
1428        let cut = (ref_ratio * samples.len() as f64) as usize;
1429        let output_samples = samples[cut..].to_vec();
1430
1431        info!(
1432            "TTS voice clone: waveform {} samples ({:.2}s), trimmed ref {} samples",
1433            output_samples.len(),
1434            output_samples.len() as f64 / SAMPLE_RATE as f64,
1435            cut,
1436        );
1437
1438        Ok(output_samples)
1439    }
1440}
1441
1442// ── Utility functions ───────────────────────────────────────────────────
1443
1444/// Find safetensor files matching a prefix in a directory.
1445fn find_safetensor_files(dir: &std::path::Path, prefix: &str) -> Result<Vec<std::path::PathBuf>> {
1446    // Try single file first
1447    let single = dir.join(format!("{prefix}.safetensors"));
1448    if single.exists() {
1449        return Ok(vec![single]);
1450    }
1451
1452    // Try sharded: prefix-00001-of-00005.safetensors, etc.
1453    let mut files: Vec<std::path::PathBuf> = Vec::new();
1454    if let Ok(entries) = std::fs::read_dir(dir) {
1455        for entry in entries.flatten() {
1456            let path = entry.path();
1457            if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
1458                if name.starts_with(prefix)
1459                    && name.ends_with(".safetensors")
1460                    && name != format!("{prefix}.safetensors")
1461                {
1462                    files.push(path);
1463                }
1464            }
1465        }
1466    }
1467    files.sort();
1468
1469    if files.is_empty() {
1470        Err(FerrumError::model(format!(
1471            "no safetensors files with prefix '{prefix}' in {}",
1472            dir.display()
1473        )))
1474    } else {
1475        Ok(files)
1476    }
1477}
1478
1479/// Load a BPE tokenizer from vocab.json + merges.txt.
1480fn load_bpe_tokenizer(dir: &std::path::Path) -> Result<tokenizers::Tokenizer> {
1481    // Try tokenizer.json first (HF fast tokenizer format)
1482    let tokenizer_json = dir.join("tokenizer.json");
1483    if tokenizer_json.exists() {
1484        return tokenizers::Tokenizer::from_file(&tokenizer_json)
1485            .map_err(|e| FerrumError::model(format!("load tokenizer.json: {e}")));
1486    }
1487
1488    // Fallback: build from vocab.json + merges.txt
1489    let vocab_path = dir.join("vocab.json");
1490    let merges_path = dir.join("merges.txt");
1491
1492    if !vocab_path.exists() || !merges_path.exists() {
1493        return Err(FerrumError::model(
1494            "tokenizer.json not found, and vocab.json + merges.txt not found either",
1495        ));
1496    }
1497
1498    let vocab_data = std::fs::read_to_string(&vocab_path)
1499        .map_err(|e| FerrumError::model(format!("read vocab.json: {e}")))?;
1500    let vocab: HashMap<String, u32> = serde_json::from_str(&vocab_data)
1501        .map_err(|e| FerrumError::model(format!("parse vocab.json: {e}")))?;
1502
1503    let merges_data = std::fs::read_to_string(&merges_path)
1504        .map_err(|e| FerrumError::model(format!("read merges.txt: {e}")))?;
1505    let merges: Vec<(String, String)> = merges_data
1506        .lines()
1507        .skip(1) // skip header line
1508        .filter(|line| !line.is_empty())
1509        .filter_map(|line| {
1510            let parts: Vec<&str> = line.splitn(2, ' ').collect();
1511            if parts.len() == 2 {
1512                Some((parts[0].to_string(), parts[1].to_string()))
1513            } else {
1514                None
1515            }
1516        })
1517        .collect();
1518
1519    let bpe = tokenizers::models::bpe::BPE::from_file(
1520        vocab_path.to_str().unwrap(),
1521        merges_path.to_str().unwrap(),
1522    )
1523    .build()
1524    .map_err(|e| FerrumError::model(format!("build BPE: {e}")))?;
1525
1526    let tokenizer = tokenizers::Tokenizer::new(bpe);
1527    Ok(tokenizer)
1528}
1529
1530/// Extract logits as Vec<f32> from a [1, 1, vocab] or [1, vocab] tensor.
1531fn logits_to_vec(logits: &Tensor) -> Result<Vec<f32>> {
1532    let logits = if logits.dims().len() == 3 {
1533        logits
1534            .squeeze(0)
1535            .map_err(|e| FerrumError::model(format!("squeeze: {e}")))?
1536            .squeeze(0)
1537            .map_err(|e| FerrumError::model(format!("squeeze: {e}")))?
1538    } else if logits.dims().len() == 2 {
1539        logits
1540            .squeeze(0)
1541            .map_err(|e| FerrumError::model(format!("squeeze: {e}")))?
1542    } else {
1543        logits.clone()
1544    };
1545
1546    logits
1547        .to_vec1()
1548        .map_err(|e| FerrumError::model(format!("logits to_vec1: {e}")))
1549}
1550
1551/// Sample a token from logits with temperature, top-k, and repetition penalty.
1552/// Sample a token matching qwen3-tts-rs reference:
1553/// 1. temperature scaling
1554/// 2. top-k filter (keep top_k, rest = -inf)
1555/// 3. top-p filter (keep smallest set with cumprob > top_p, rest = -inf)
1556/// 4. softmax over filtered logits
1557/// 5. multinomial sample from distribution
1558pub fn sample_token(
1559    logits: &[f32],
1560    temperature: f32,
1561    top_k: usize,
1562    _repetition_penalty: f32,
1563) -> u32 {
1564    if temperature < 0.01 {
1565        return argmax(logits);
1566    }
1567
1568    let vocab = logits.len();
1569
1570    // 1. Apply temperature
1571    let scaled: Vec<f32> = logits.iter().map(|&x| x / temperature).collect();
1572
1573    // 2. Top-k filter: keep top_k values, set rest to -inf
1574    let mut filtered = scaled.clone();
1575    if top_k > 0 && top_k < vocab {
1576        let mut sorted = scaled.clone();
1577        sorted.sort_unstable_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
1578        let threshold = sorted[top_k - 1];
1579        for v in &mut filtered {
1580            if *v < threshold {
1581                *v = f32::NEG_INFINITY;
1582            }
1583        }
1584    }
1585
1586    // 3. Top-p filter: keep smallest set of tokens whose cumulative prob >= top_p
1587    const TOP_P: f32 = 0.9;
1588    {
1589        let mut indices: Vec<usize> = (0..vocab).collect();
1590        indices.sort_unstable_by(|&a, &b| {
1591            filtered[b]
1592                .partial_cmp(&filtered[a])
1593                .unwrap_or(std::cmp::Ordering::Equal)
1594        });
1595
1596        // Softmax over sorted values for cumulative prob
1597        let max_val = filtered[indices[0]];
1598        let exp_sorted: Vec<f32> = indices
1599            .iter()
1600            .map(|&i| (filtered[i] - max_val).exp())
1601            .collect();
1602        let sum: f32 = exp_sorted.iter().sum();
1603        let probs_sorted: Vec<f32> = exp_sorted.iter().map(|e| e / sum).collect();
1604
1605        // Find cutoff
1606        let mut cumsum = 0.0f32;
1607        let mut cutoff_idx = vocab;
1608        for (i, &p) in probs_sorted.iter().enumerate() {
1609            cumsum += p;
1610            if cumsum > TOP_P {
1611                cutoff_idx = i + 1;
1612                break;
1613            }
1614        }
1615
1616        // Mask out tokens beyond cutoff
1617        for &idx in &indices[cutoff_idx..] {
1618            filtered[idx] = f32::NEG_INFINITY;
1619        }
1620    }
1621
1622    // 4. Softmax over filtered logits
1623    let max_val = filtered.iter().copied().fold(f32::NEG_INFINITY, f32::max);
1624    let exps: Vec<f32> = filtered.iter().map(|&v| (v - max_val).exp()).collect();
1625    let sum: f32 = exps.iter().sum();
1626    let probs: Vec<f32> = exps.iter().map(|e| e / sum).collect();
1627
1628    // 5. Multinomial sample
1629    let r = rand_f32();
1630    let mut cumulative = 0.0f32;
1631    for (i, &p) in probs.iter().enumerate() {
1632        cumulative += p;
1633        if cumulative >= r {
1634            return i as u32;
1635        }
1636    }
1637    // Fallback
1638    argmax(&probs)
1639}
1640
/// Index of the largest value in `v`, or `0` when `v` is empty.
///
/// When several entries compare as equally maximal, the last one wins; an
/// incomparable pair (NaN involved) is treated as equal, so the later entry
/// also replaces the current winner in that case.
fn argmax(v: &[f32]) -> u32 {
    let Some((&first, rest)) = v.split_first() else {
        return 0;
    };

    let mut winner = 0usize;
    let mut winner_val = first;
    for (offset, &candidate) in rest.iter().enumerate() {
        // Keep the current winner only when it is strictly greater.
        if winner_val.partial_cmp(&candidate) != Some(std::cmp::Ordering::Greater) {
            winner = offset + 1;
            winner_val = candidate;
        }
    }
    winner as u32
}
1648
/// Draw a pseudo-random `f32` uniformly from `[0, 1)`.
///
/// The generator is seeded per call from the wall clock's sub-second
/// nanoseconds plus a process-wide call counter, then scrambled with the
/// SplitMix64 finalizer. It is not reproducible across runs and not
/// cryptographically secure.
///
/// The earlier form — `(seed + count) * 1103515245 + 12345` divided by
/// `u64::MAX` — never exceeded about 1.1e18, so its output stayed within
/// roughly `[0, 0.06]` and skewed sampling toward low token ids.
fn rand_f32() -> f32 {
    use std::sync::atomic::{AtomicU64, Ordering};
    static COUNTER: AtomicU64 = AtomicU64::new(0);

    let seed = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .subsec_nanos() as u64;
    let count = COUNTER.fetch_add(1, Ordering::Relaxed);

    // Spread consecutive counter values across the 64-bit space, then mix
    // (SplitMix64 finalizer) so every output bit depends on every input bit.
    let mut z = seed.wrapping_add(count.wrapping_mul(0x9E37_79B9_7F4A_7C15));
    z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    z ^= z >> 31;

    // The top 24 bits fit an f32 mantissa exactly, giving a value in [0, 1).
    (z >> 40) as f32 / (1u32 << 24) as f32
}
1666
1667// ── Dummy KV cache + ModelExecutor trait impl ───────────────────────────
1668
1669#[derive(Clone, Debug)]
1670#[allow(dead_code)]
1671struct DummyTtsCache;
1672
1673impl ferrum_interfaces::KvCacheHandle for DummyTtsCache {
1674    fn block_table(&self) -> &ferrum_interfaces::BlockTable {
1675        static EMPTY: std::sync::OnceLock<ferrum_interfaces::BlockTable> =
1676            std::sync::OnceLock::new();
1677        EMPTY.get_or_init(|| ferrum_interfaces::BlockTable::new(16))
1678    }
1679    fn block_table_mut(&mut self) -> &mut ferrum_interfaces::BlockTable {
1680        unimplemented!()
1681    }
1682    fn as_any(&self) -> &dyn std::any::Any {
1683        self
1684    }
1685    fn device(&self) -> Device {
1686        Device::CPU
1687    }
1688    fn num_layers(&self) -> usize {
1689        0
1690    }
1691    fn num_heads(&self) -> usize {
1692        0
1693    }
1694    fn head_dim(&self) -> usize {
1695        0
1696    }
1697    fn key_cache(&self, _: usize) -> Result<Option<TensorRef>> {
1698        Ok(None)
1699    }
1700    fn value_cache(&self, _: usize) -> Result<Option<TensorRef>> {
1701        Ok(None)
1702    }
1703    fn clone_handle(&self) -> Result<Arc<dyn ferrum_interfaces::KvCacheHandle>> {
1704        Ok(Arc::new(self.clone()))
1705    }
1706    fn stats(&self) -> ferrum_interfaces::CacheHandleStats {
1707        ferrum_interfaces::CacheHandleStats {
1708            memory_bytes: 0,
1709            blocks_allocated: 0,
1710            tokens_stored: 0,
1711            utilization: 0.0,
1712            last_access: std::time::Instant::now(),
1713        }
1714    }
1715    fn is_valid(&self) -> bool {
1716        true
1717    }
1718    fn cache_id(&self) -> String {
1719        "tts_dummy".to_string()
1720    }
1721}
1722
1723#[async_trait]
1724impl ModelExecutor for TtsModelExecutor {
1725    fn info(&self) -> &ModelInfo {
1726        &self.info
1727    }
1728
1729    async fn prefill(&self, _input: &PrefillInput) -> Result<PrefillOutput> {
1730        Err(FerrumError::model(
1731            "TTS uses synthesize(), not prefill/decode",
1732        ))
1733    }
1734
1735    async fn decode(&self, _input: &DecodeInput) -> Result<DecodeOutput> {
1736        Err(FerrumError::model(
1737            "TTS uses synthesize(), not prefill/decode",
1738        ))
1739    }
1740
1741    fn capabilities(&self) -> ExecutorCapabilities {
1742        ExecutorCapabilities {
1743            max_batch_size: 1,
1744            max_sequence_length: self.info.max_sequence_length,
1745            attention_mechanisms: vec![AttentionType::GroupedQuery],
1746            supports_dynamic_batching: false,
1747            supports_continuous_batching: false,
1748            supports_speculative_decoding: false,
1749            supports_tensor_parallelism: false,
1750            supports_pipeline_parallelism: false,
1751            supported_dtypes: vec![DataType::FP32, DataType::BF16],
1752            supported_devices: vec![self.info.device.clone()],
1753            memory_requirements: MemoryRequirements {
1754                parameter_memory: 0,
1755                activation_memory_per_token: 0,
1756                kv_cache_memory_per_token: 0,
1757                overhead_memory: 0,
1758            },
1759        }
1760    }
1761
1762    fn release_cache(&self, _: &str) {}
1763
1764    fn status(&self) -> ferrum_interfaces::model_executor::ExecutorStatus {
1765        common::default_executor_status()
1766    }
1767}