use candle_core::{DType, Device as CandleDevice, IndexOp, Module, Tensor};
use candle_nn::VarBuilder;
use ferrum_types::{FerrumError, Result};
use tracing::info;
const HIDDEN_SIZE: usize = 512;
const NUM_HEADS: usize = 8;
const NUM_TRANSFORMER_LAYERS: usize = 8;
const SEMANTIC_CODEBOOK_SIZE: usize = 2048;
const ACOUSTIC_CODEBOOK_SIZE: usize = 2048;
const CODEBOOK_DIM: usize = 256;
const NUM_ACOUSTIC_CODEBOOKS: usize = 31;
const NUM_OUTPUT_CODEBOOKS: usize = 16;
pub struct SpeechTokenizerEncoder {
conv_stack: candle_transformers::models::mimi::seanet::SeaNetEncoder,
transformer:
parking_lot::Mutex<candle_transformers::models::mimi::transformer::ProjectedTransformer>,
downsample: candle_transformers::models::mimi::conv::ConvDownsample1d,
quantizer: candle_transformers::models::mimi::quantization::SplitResidualVectorQuantizer,
device: CandleDevice,
}
impl SpeechTokenizerEncoder {
pub fn load(vb: VarBuilder, device: CandleDevice) -> Result<Self> {
let mimi_cfg = candle_transformers::models::mimi::Config::v0_1(Some(NUM_OUTPUT_CODEBOOKS));
let conv_stack = candle_transformers::models::mimi::seanet::SeaNetEncoder::new(
&mimi_cfg.seanet,
vb.pp("encoder"),
)
.map_err(|e| FerrumError::model(format!("encoder conv stack: {e}")))?;
let transformer =
candle_transformers::models::mimi::transformer::ProjectedTransformer::new(
mimi_cfg.seanet.dimension,
&[mimi_cfg.seanet.dimension],
&mimi_cfg.transformer,
vb.pp("encoder_transformer"),
)
.map_err(|e| FerrumError::model(format!("encoder transformer: {e}")))?;
let downsample = candle_transformers::models::mimi::conv::ConvDownsample1d::new(
2, mimi_cfg.seanet.dimension,
true, true, vb.pp("downsample"),
)
.map_err(|e| FerrumError::model(format!("encoder downsample: {e}")))?;
let quantizer =
candle_transformers::models::mimi::quantization::SplitResidualVectorQuantizer::new(
CODEBOOK_DIM,
Some(HIDDEN_SIZE),
Some(HIDDEN_SIZE),
NUM_OUTPUT_CODEBOOKS,
SEMANTIC_CODEBOOK_SIZE,
vb.pp("quantizer"),
)
.map_err(|e| FerrumError::model(format!("encoder quantizer: {e}")))?;
info!(
"SpeechTokenizerEncoder loaded: conv=15 layers (960x ds) + 2x downsample, \
transformer={} layers (h={}, heads={}), \
RVQ=1x{}+{}x{} → {} codebooks",
NUM_TRANSFORMER_LAYERS,
HIDDEN_SIZE,
NUM_HEADS,
SEMANTIC_CODEBOOK_SIZE,
NUM_ACOUSTIC_CODEBOOKS,
ACOUSTIC_CODEBOOK_SIZE,
NUM_OUTPUT_CODEBOOKS,
);
Ok(Self {
conv_stack,
transformer: parking_lot::Mutex::new(transformer),
downsample,
quantizer,
device,
})
}
pub fn encode(&self, pcm: &[f32]) -> Result<Vec<Vec<u32>>> {
let num_samples = pcm.len();
info!(
"SpeechTokenizerEncoder: encoding {} samples ({:.2}s @ 24kHz)",
num_samples,
num_samples as f64 / 24000.0,
);
let input = Tensor::from_vec(pcm.to_vec(), (1, 1, num_samples), &self.device)
.map_err(|e| FerrumError::model(format!("input tensor: {e}")))?
.to_dtype(DType::F32)
.map_err(|e| FerrumError::model(format!("input dtype: {e}")))?;
let conv_out = input
.apply(&self.conv_stack)
.map_err(|e| FerrumError::model(format!("conv encoder: {e}")))?;
let mut transformer = self.transformer.lock();
let hidden = transformer
.forward(&conv_out)
.map_err(|e| FerrumError::model(format!("encoder transformer: {e}")))?;
let hidden = &hidden[0];
let hidden = hidden
.apply(&self.downsample)
.map_err(|e| FerrumError::model(format!("encoder downsample: {e}")))?;
let codes = self
.quantizer
.encode(&hidden)
.map_err(|e| FerrumError::model(format!("quantizer encode: {e}")))?;
let codes = codes
.squeeze(0)
.map_err(|e| FerrumError::model(format!("squeeze: {e}")))?
.transpose(0, 1)
.map_err(|e| FerrumError::model(format!("transpose: {e}")))?
.to_dtype(DType::U32)
.map_err(|e| FerrumError::model(format!("to_u32: {e}")))?;
let t = codes
.dim(0)
.map_err(|e| FerrumError::model(format!("dim: {e}")))?;
let k = codes
.dim(1)
.map_err(|e| FerrumError::model(format!("dim1: {e}")))?;
info!("SpeechTokenizerEncoder: {} frames, {} codebooks", t, k);
let mut result = Vec::with_capacity(t);
for ti in 0..t {
let row: Vec<u32> = codes
.i(ti)
.and_then(|r| r.to_vec1())
.map_err(|e| FerrumError::model(format!("codes row: {e}")))?;
result.push(row);
}
Ok(result)
}
}