ferrum_models/architectures/
speech_tokenizer_encoder.rs1use candle_core::{DType, Device as CandleDevice, IndexOp, Module, Tensor};
12use candle_nn::VarBuilder;
13use ferrum_types::{FerrumError, Result};
14use tracing::info;
15
16const HIDDEN_SIZE: usize = 512;
19const NUM_HEADS: usize = 8;
20const NUM_TRANSFORMER_LAYERS: usize = 8;
21const SEMANTIC_CODEBOOK_SIZE: usize = 2048;
22const ACOUSTIC_CODEBOOK_SIZE: usize = 2048;
23const CODEBOOK_DIM: usize = 256;
24const NUM_ACOUSTIC_CODEBOOKS: usize = 31;
25const NUM_OUTPUT_CODEBOOKS: usize = 16; pub struct SpeechTokenizerEncoder {
34 conv_stack: candle_transformers::models::mimi::seanet::SeaNetEncoder,
35 transformer:
36 parking_lot::Mutex<candle_transformers::models::mimi::transformer::ProjectedTransformer>,
37 downsample: candle_transformers::models::mimi::conv::ConvDownsample1d,
38 quantizer: candle_transformers::models::mimi::quantization::SplitResidualVectorQuantizer,
39 device: CandleDevice,
40}
41
42impl SpeechTokenizerEncoder {
43 pub fn load(vb: VarBuilder, device: CandleDevice) -> Result<Self> {
45 let mimi_cfg = candle_transformers::models::mimi::Config::v0_1(Some(NUM_OUTPUT_CODEBOOKS));
46
47 let conv_stack = candle_transformers::models::mimi::seanet::SeaNetEncoder::new(
48 &mimi_cfg.seanet,
49 vb.pp("encoder"),
50 )
51 .map_err(|e| FerrumError::model(format!("encoder conv stack: {e}")))?;
52
53 let transformer =
54 candle_transformers::models::mimi::transformer::ProjectedTransformer::new(
55 mimi_cfg.seanet.dimension,
56 &[mimi_cfg.seanet.dimension],
57 &mimi_cfg.transformer,
58 vb.pp("encoder_transformer"),
59 )
60 .map_err(|e| FerrumError::model(format!("encoder transformer: {e}")))?;
61
62 let downsample = candle_transformers::models::mimi::conv::ConvDownsample1d::new(
63 2, mimi_cfg.seanet.dimension,
65 true, true, vb.pp("downsample"),
68 )
69 .map_err(|e| FerrumError::model(format!("encoder downsample: {e}")))?;
70
71 let quantizer =
72 candle_transformers::models::mimi::quantization::SplitResidualVectorQuantizer::new(
73 CODEBOOK_DIM,
74 Some(HIDDEN_SIZE),
75 Some(HIDDEN_SIZE),
76 NUM_OUTPUT_CODEBOOKS,
77 SEMANTIC_CODEBOOK_SIZE,
78 vb.pp("quantizer"),
79 )
80 .map_err(|e| FerrumError::model(format!("encoder quantizer: {e}")))?;
81
82 info!(
83 "SpeechTokenizerEncoder loaded: conv=15 layers (960x ds) + 2x downsample, \
84 transformer={} layers (h={}, heads={}), \
85 RVQ=1x{}+{}x{} → {} codebooks",
86 NUM_TRANSFORMER_LAYERS,
87 HIDDEN_SIZE,
88 NUM_HEADS,
89 SEMANTIC_CODEBOOK_SIZE,
90 NUM_ACOUSTIC_CODEBOOKS,
91 ACOUSTIC_CODEBOOK_SIZE,
92 NUM_OUTPUT_CODEBOOKS,
93 );
94
95 Ok(Self {
96 conv_stack,
97 transformer: parking_lot::Mutex::new(transformer),
98 downsample,
99 quantizer,
100 device,
101 })
102 }
103
104 pub fn encode(&self, pcm: &[f32]) -> Result<Vec<Vec<u32>>> {
106 let num_samples = pcm.len();
107 info!(
108 "SpeechTokenizerEncoder: encoding {} samples ({:.2}s @ 24kHz)",
109 num_samples,
110 num_samples as f64 / 24000.0,
111 );
112
113 let input = Tensor::from_vec(pcm.to_vec(), (1, 1, num_samples), &self.device)
114 .map_err(|e| FerrumError::model(format!("input tensor: {e}")))?
115 .to_dtype(DType::F32)
116 .map_err(|e| FerrumError::model(format!("input dtype: {e}")))?;
117
118 let conv_out = input
120 .apply(&self.conv_stack)
121 .map_err(|e| FerrumError::model(format!("conv encoder: {e}")))?;
122
123 let mut transformer = self.transformer.lock();
124 let hidden = transformer
125 .forward(&conv_out)
126 .map_err(|e| FerrumError::model(format!("encoder transformer: {e}")))?;
127 let hidden = &hidden[0];
128
129 let hidden = hidden
130 .apply(&self.downsample)
131 .map_err(|e| FerrumError::model(format!("encoder downsample: {e}")))?;
132
133 let codes = self
134 .quantizer
135 .encode(&hidden)
136 .map_err(|e| FerrumError::model(format!("quantizer encode: {e}")))?;
137
138 let codes = codes
140 .squeeze(0)
141 .map_err(|e| FerrumError::model(format!("squeeze: {e}")))?
142 .transpose(0, 1)
143 .map_err(|e| FerrumError::model(format!("transpose: {e}")))?
144 .to_dtype(DType::U32)
145 .map_err(|e| FerrumError::model(format!("to_u32: {e}")))?;
146
147 let t = codes
148 .dim(0)
149 .map_err(|e| FerrumError::model(format!("dim: {e}")))?;
150 let k = codes
151 .dim(1)
152 .map_err(|e| FerrumError::model(format!("dim1: {e}")))?;
153 info!("SpeechTokenizerEncoder: {} frames, {} codebooks", t, k);
154
155 let mut result = Vec::with_capacity(t);
156 for ti in 0..t {
157 let row: Vec<u32> = codes
158 .i(ti)
159 .and_then(|r| r.to_vec1())
160 .map_err(|e| FerrumError::model(format!("codes row: {e}")))?;
161 result.push(row);
162 }
163 Ok(result)
164 }
165}