Skip to main content

ferrum_models/architectures/
speech_tokenizer_encoder.rs

1//! Mimi-based speech tokenizer encoder for Qwen3-TTS ICL voice cloning.
2//!
3//! Takes raw 24kHz audio and outputs codec token indices [T, 16]
4//! (1 semantic + 15 acoustic codebooks).
5//!
6//! Uses candle-transformers' Mimi components directly (SeaNetEncoder,
7//! ProjectedTransformer, ConvDownsample1d, SplitResidualVectorQuantizer).
8//!
9//! Loaded from speech_tokenizer/model.safetensors.
10
11use candle_core::{DType, Device as CandleDevice, IndexOp, Module, Tensor};
12use candle_nn::VarBuilder;
13use ferrum_types::{FerrumError, Result};
14use tracing::info;
15
16// ── Config ──────────────────────────────────────────────────────────────
17
// Transformer shape. `NUM_HEADS` and `NUM_TRANSFORMER_LAYERS` are only reported in the
// load-time log line; the transformer itself is built from the Mimi preset config.

/// Hidden width; passed to the quantizer as both its input and output dimension.
const HIDDEN_SIZE: usize = 512;
/// Attention head count (log line only).
const NUM_HEADS: usize = 8;
/// Transformer layer count (log line only).
const NUM_TRANSFORMER_LAYERS: usize = 8;

// Residual vector quantizer layout.

/// Entries per codebook handed to the quantizer constructor.
const SEMANTIC_CODEBOOK_SIZE: usize = 2048;
/// Entries per acoustic codebook (log line only).
const ACOUSTIC_CODEBOOK_SIZE: usize = 2048;
/// Dimension of each codebook vector inside the quantizer.
const CODEBOOK_DIM: usize = 256;
/// Acoustic codebook count (log line only).
// NOTE(review): the quantizer is constructed with `NUM_OUTPUT_CODEBOOKS`, not this value —
// presumably this is the count stored in the checkpoint; confirm.
const NUM_ACOUSTIC_CODEBOOKS: usize = 31;
/// Number of codebooks emitted per frame: 1 semantic + 15 acoustic.
const NUM_OUTPUT_CODEBOOKS: usize = 16;
26
27// ── SpeechTokenizerEncoder ──────────────────────────────────────────────
28
29/// Mimi-based speech tokenizer encoder: raw 24kHz PCM → codec tokens.
30///
31/// All components use candle-transformers' Mimi implementation for correctness.
32/// This is a cold path (runs once per voice clone), so performance is secondary.
33pub struct SpeechTokenizerEncoder {
34    conv_stack: candle_transformers::models::mimi::seanet::SeaNetEncoder,
35    transformer:
36        parking_lot::Mutex<candle_transformers::models::mimi::transformer::ProjectedTransformer>,
37    downsample: candle_transformers::models::mimi::conv::ConvDownsample1d,
38    quantizer: candle_transformers::models::mimi::quantization::SplitResidualVectorQuantizer,
39    device: CandleDevice,
40}
41
42impl SpeechTokenizerEncoder {
43    /// Load from VarBuilder scoped to `encoder.` prefix.
44    pub fn load(vb: VarBuilder, device: CandleDevice) -> Result<Self> {
45        let mimi_cfg = candle_transformers::models::mimi::Config::v0_1(Some(NUM_OUTPUT_CODEBOOKS));
46
47        let conv_stack = candle_transformers::models::mimi::seanet::SeaNetEncoder::new(
48            &mimi_cfg.seanet,
49            vb.pp("encoder"),
50        )
51        .map_err(|e| FerrumError::model(format!("encoder conv stack: {e}")))?;
52
53        let transformer =
54            candle_transformers::models::mimi::transformer::ProjectedTransformer::new(
55                mimi_cfg.seanet.dimension,
56                &[mimi_cfg.seanet.dimension],
57                &mimi_cfg.transformer,
58                vb.pp("encoder_transformer"),
59            )
60            .map_err(|e| FerrumError::model(format!("encoder transformer: {e}")))?;
61
62        let downsample = candle_transformers::models::mimi::conv::ConvDownsample1d::new(
63            2, // stride: 25Hz → 12.5Hz
64            mimi_cfg.seanet.dimension,
65            true, // causal
66            true, // learnt
67            vb.pp("downsample"),
68        )
69        .map_err(|e| FerrumError::model(format!("encoder downsample: {e}")))?;
70
71        let quantizer =
72            candle_transformers::models::mimi::quantization::SplitResidualVectorQuantizer::new(
73                CODEBOOK_DIM,
74                Some(HIDDEN_SIZE),
75                Some(HIDDEN_SIZE),
76                NUM_OUTPUT_CODEBOOKS,
77                SEMANTIC_CODEBOOK_SIZE,
78                vb.pp("quantizer"),
79            )
80            .map_err(|e| FerrumError::model(format!("encoder quantizer: {e}")))?;
81
82        info!(
83            "SpeechTokenizerEncoder loaded: conv=15 layers (960x ds) + 2x downsample, \
84             transformer={} layers (h={}, heads={}), \
85             RVQ=1x{}+{}x{} → {} codebooks",
86            NUM_TRANSFORMER_LAYERS,
87            HIDDEN_SIZE,
88            NUM_HEADS,
89            SEMANTIC_CODEBOOK_SIZE,
90            NUM_ACOUSTIC_CODEBOOKS,
91            ACOUSTIC_CODEBOOK_SIZE,
92            NUM_OUTPUT_CODEBOOKS,
93        );
94
95        Ok(Self {
96            conv_stack,
97            transformer: parking_lot::Mutex::new(transformer),
98            downsample,
99            quantizer,
100            device,
101        })
102    }
103
104    /// Encode 24kHz mono PCM → codec token indices `[T, 16]`.
105    pub fn encode(&self, pcm: &[f32]) -> Result<Vec<Vec<u32>>> {
106        let num_samples = pcm.len();
107        info!(
108            "SpeechTokenizerEncoder: encoding {} samples ({:.2}s @ 24kHz)",
109            num_samples,
110            num_samples as f64 / 24000.0,
111        );
112
113        let input = Tensor::from_vec(pcm.to_vec(), (1, 1, num_samples), &self.device)
114            .map_err(|e| FerrumError::model(format!("input tensor: {e}")))?
115            .to_dtype(DType::F32)
116            .map_err(|e| FerrumError::model(format!("input dtype: {e}")))?;
117
118        // Conv encoder → Transformer → Downsample → Quantize
119        let conv_out = input
120            .apply(&self.conv_stack)
121            .map_err(|e| FerrumError::model(format!("conv encoder: {e}")))?;
122
123        let mut transformer = self.transformer.lock();
124        let hidden = transformer
125            .forward(&conv_out)
126            .map_err(|e| FerrumError::model(format!("encoder transformer: {e}")))?;
127        let hidden = &hidden[0];
128
129        let hidden = hidden
130            .apply(&self.downsample)
131            .map_err(|e| FerrumError::model(format!("encoder downsample: {e}")))?;
132
133        let codes = self
134            .quantizer
135            .encode(&hidden)
136            .map_err(|e| FerrumError::model(format!("quantizer encode: {e}")))?;
137
138        // [1, 16, T] → Vec<Vec<u32>> as [T, 16]
139        let codes = codes
140            .squeeze(0)
141            .map_err(|e| FerrumError::model(format!("squeeze: {e}")))?
142            .transpose(0, 1)
143            .map_err(|e| FerrumError::model(format!("transpose: {e}")))?
144            .to_dtype(DType::U32)
145            .map_err(|e| FerrumError::model(format!("to_u32: {e}")))?;
146
147        let t = codes
148            .dim(0)
149            .map_err(|e| FerrumError::model(format!("dim: {e}")))?;
150        let k = codes
151            .dim(1)
152            .map_err(|e| FerrumError::model(format!("dim1: {e}")))?;
153        info!("SpeechTokenizerEncoder: {} frames, {} codebooks", t, k);
154
155        let mut result = Vec::with_capacity(t);
156        for ti in 0..t {
157            let row: Vec<u32> = codes
158                .i(ti)
159                .and_then(|r| r.to_vec1())
160                .map_err(|e| FerrumError::model(format!("codes row: {e}")))?;
161            result.push(row);
162        }
163        Ok(result)
164    }
165}