use super::vector_quantizer::{VQConfig, VectorQuantizer};
use crate::error::{TokenizerError, TokenizerResult};
use crate::SignalTokenizer;
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::thread_rng;
/// Tokenizer that chains a linear encoder, a [`VectorQuantizer`] and a linear
/// decoder.
///
/// A signal of length `input_dim` is multiplied by `encoder`, the result is
/// handed to `quantizer`, and codebook vectors are mapped back to signal space
/// by multiplying with `decoder`.
#[derive(Debug, Clone)]
pub struct VQVAETokenizer {
// Encoder weights with shape `(input_dim, embed_dim)`.
encoder: Array2<f32>,
// Quantizer that owns the codebook and performs the lookup.
quantizer: VectorQuantizer,
// Decoder weights with shape `(embed_dim, input_dim)`.
decoder: Array2<f32>,
// Length every input signal must have.
input_dim: usize,
}
impl VQVAETokenizer {
    /// Builds a tokenizer for signals of length `input_dim`.
    ///
    /// Both weight matrices are filled with uniform random values in
    /// `[-b, b)` where `b = sqrt(2 / (input_dim + embed_dim))`. The encoder has
    /// shape `(input_dim, embed_dim)` and the decoder `(embed_dim, input_dim)`.
    /// The quantizer is constructed from `config`.
    pub fn new(input_dim: usize, config: VQConfig) -> Self {
        let mut rng = thread_rng();
        let embed_dim = config.embed_dim;

        // Encoder and decoder share the same fan-in + fan-out sum, so a single
        // bound serves both matrices.
        let init_bound = (2.0 / (input_dim + embed_dim) as f32).sqrt();
        let mut sample_weights = |shape: (usize, usize)| {
            Array2::from_shape_fn(shape, |_| (rng.random::<f32>() - 0.5) * 2.0 * init_bound)
        };

        // Encoder is sampled first, then the decoder.
        let encoder = sample_weights((input_dim, embed_dim));
        let decoder = sample_weights((embed_dim, input_dim));

        Self {
            encoder,
            quantizer: VectorQuantizer::new(config),
            decoder,
            input_dim,
        }
    }

    /// Projects `signal` through the encoder and quantizes the latent vector.
    ///
    /// Returns whatever [`VectorQuantizer::quantize`] yields for the latent; the
    /// rest of this file uses the `usize` as a codebook index.
    ///
    /// # Errors
    ///
    /// Returns a dimension-mismatch error when `signal.len() != input_dim`, and
    /// forwards any error from the quantizer.
    pub fn encode_quantized(&self, signal: &Array1<f32>) -> TokenizerResult<(usize, Array1<f32>)> {
        let actual_len = signal.len();
        if actual_len != self.input_dim {
            return Err(TokenizerError::dim_mismatch(
                self.input_dim,
                actual_len,
                "dimension validation",
            ));
        }

        self.quantizer.quantize(&signal.dot(&self.encoder))
    }

    /// Maps a vector of length `embed_dim` back to signal space by multiplying
    /// it with the decoder matrix.
    ///
    /// # Errors
    ///
    /// Returns a dimension-mismatch error when `quantized.len()` differs from
    /// the quantizer's embedding dimension.
    pub fn decode_quantized(&self, quantized: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
        let expected_len = self.quantizer.embed_dim();
        let actual_len = quantized.len();

        if actual_len == expected_len {
            Ok(quantized.dot(&self.decoder))
        } else {
            Err(TokenizerError::dim_mismatch(
                expected_len,
                actual_len,
                "dimension validation",
            ))
        }
    }

    /// Fetches codebook entry `idx` from the quantizer and decodes it.
    ///
    /// # Errors
    ///
    /// Forwards errors from the codebook lookup and from
    /// [`Self::decode_quantized`].
    pub fn decode_from_index(&self, idx: usize) -> TokenizerResult<Array1<f32>> {
        let code_vector = self.quantizer.get_codebook_entry(idx)?;
        self.decode_quantized(&code_vector)
    }

    /// Shared reference to the underlying quantizer.
    pub fn quantizer(&self) -> &VectorQuantizer {
        &self.quantizer
    }

    /// Mutable reference to the underlying quantizer.
    pub fn quantizer_mut(&mut self) -> &mut VectorQuantizer {
        &mut self.quantizer
    }

    /// Encoder weight matrix, shape `(input_dim, embed_dim)`.
    pub fn encoder(&self) -> &Array2<f32> {
        &self.encoder
    }

    /// Decoder weight matrix, shape `(embed_dim, input_dim)`.
    pub fn decoder(&self) -> &Array2<f32> {
        &self.decoder
    }

    /// Replaces the encoder weights.
    ///
    /// # Errors
    ///
    /// Returns a dimension-mismatch error unless `weights` has shape
    /// `(input_dim, embed_dim)`.
    ///
    /// NOTE(review): the error carries element counts rather than shapes, so a
    /// transposed matrix reports identical "expected" and "actual" values.
    pub fn set_encoder(&mut self, weights: Array2<f32>) -> TokenizerResult<()> {
        let rows = self.input_dim;
        let cols = self.quantizer.embed_dim();

        if weights.shape() == [rows, cols] {
            self.encoder = weights;
            Ok(())
        } else {
            Err(TokenizerError::dim_mismatch(
                rows * cols,
                weights.len(),
                "dimension validation",
            ))
        }
    }

    /// Replaces the decoder weights.
    ///
    /// # Errors
    ///
    /// Returns a dimension-mismatch error unless `weights` has shape
    /// `(embed_dim, input_dim)`.
    ///
    /// NOTE(review): the error carries element counts rather than shapes, so a
    /// transposed matrix reports identical "expected" and "actual" values.
    pub fn set_decoder(&mut self, weights: Array2<f32>) -> TokenizerResult<()> {
        let rows = self.quantizer.embed_dim();
        let cols = self.input_dim;

        if weights.shape() == [rows, cols] {
            self.decoder = weights;
            Ok(())
        } else {
            Err(TokenizerError::dim_mismatch(
                rows * cols,
                weights.len(),
                "dimension validation",
            ))
        }
    }
}
impl SignalTokenizer for VQVAETokenizer {
    /// Encodes `signal` into a single token.
    ///
    /// The token is the index returned by [`VQVAETokenizer::encode_quantized`],
    /// cast to `f32` and stored in a length-1 array.
    ///
    /// # Errors
    ///
    /// Forwards any error from [`VQVAETokenizer::encode_quantized`].
    fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
        let (idx, _) = self.encode_quantized(signal)?;
        Ok(Array1::from_elem(1, idx as f32))
    }

    /// Decodes a single token back into a signal.
    ///
    /// The token is rounded to the nearest integer and used as a codebook
    /// index for [`VQVAETokenizer::decode_from_index`].
    ///
    /// # Errors
    ///
    /// Returns a dimension-mismatch error when `tokens` does not hold exactly
    /// one element, and forwards any error from the codebook lookup/decoding.
    fn decode(&self, tokens: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
        if tokens.len() != 1 {
            return Err(TokenizerError::dim_mismatch(
                1,
                tokens.len(),
                "dimension validation",
            ));
        }

        let rounded = tokens[0].round();
        // A float-to-int `as` cast saturates: NaN and negative values would
        // become 0 and silently decode as codebook entry 0. Route such tokens
        // to `usize::MAX` instead -- the same index that `+inf` and oversized
        // tokens already saturate to -- so the codebook lookup handles them as
        // out-of-range indices (presumably by returning an error; confirm
        // against `VectorQuantizer::get_codebook_entry`).
        let idx = if rounded.is_finite() && rounded >= 0.0 {
            rounded as usize
        } else {
            usize::MAX
        };
        self.decode_from_index(idx)
    }

    /// Embedding dimension reported by the quantizer.
    fn embed_dim(&self) -> usize {
        self.quantizer.embed_dim()
    }

    /// Vocabulary size, i.e. the quantizer's codebook size.
    fn vocab_size(&self) -> usize {
        self.quantizer.codebook_size()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoding through the trait yields one token, and decoding that token
    /// yields a signal of the original length.
    #[test]
    fn test_vqvae_tokenizer() {
        let tokenizer = VQVAETokenizer::new(
            32,
            VQConfig {
                codebook_size: 16,
                embed_dim: 8,
                ..Default::default()
            },
        );
        let samples: Vec<f32> = (0..32).map(|n| (n as f32 * 0.1).sin()).collect();
        let signal = Array1::from_vec(samples);

        let tokens = tokenizer.encode(&signal).unwrap();
        assert_eq!(tokens.len(), 1);

        let reconstruction = tokenizer.decode(&tokens).unwrap();
        assert_eq!(reconstruction.len(), 32);
    }

    /// The index produced by `encode_quantized` can be fed to
    /// `decode_from_index`, giving an output as long as the input.
    #[test]
    fn test_encode_decode_roundtrip() {
        let tokenizer = VQVAETokenizer::new(
            64,
            VQConfig {
                codebook_size: 32,
                embed_dim: 16,
                ..Default::default()
            },
        );
        let samples: Vec<f32> = (0..64).map(|n| (n as f32 * 0.05).cos()).collect();
        let signal = Array1::from_vec(samples);

        let (code_index, _) = tokenizer.encode_quantized(&signal).unwrap();
        let reconstruction = tokenizer.decode_from_index(code_index).unwrap();
        assert_eq!(reconstruction.len(), signal.len());
    }
}