use crate::error::{TokenizerError, TokenizerResult};
use crate::vqvae::{ResidualVQ, VQConfig};
use candle_core::{Device, Module, Result as CandleResult, Tensor};
use candle_nn::{
conv1d, conv_transpose1d, Conv1d, Conv1dConfig, ConvTranspose1d, ConvTranspose1dConfig,
VarBuilder,
};
use scirs2_core::ndarray::Array1;
use serde::{Deserialize, Serialize};
/// Hyper-parameters for [`NeuralCodec`] and its convolutional encoder/decoder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NeuralCodecConfig {
/// Channels of the input waveform. `NeuralCodec::encode` always builds a
/// single-channel `(1, 1, len)` tensor, so values other than 1 only make sense
/// when driving `ConvEncoder`/`ConvDecoder` directly.
pub input_channels: usize,
/// Channels after the encoder's stem conv; encoder stage `i` outputs
/// `hidden_channels * 2^(i + 1)` channels.
pub hidden_channels: usize,
/// NOTE(review): not read by `ConvEncoder`/`ConvDecoder` in this module — the
/// number of residual units per stage is `dilations.len()`.
pub num_residual_blocks: usize,
/// Down-sampling factor of each encoder stage (the decoder mirrors them in
/// reverse). Their product is the hop size returned by `compression_ratio`.
pub strides: Vec<usize>,
/// One `ResidualUnit` per entry is inserted at every encoder/decoder stage,
/// using that entry as its dilation.
pub dilations: Vec<usize>,
/// Number of entries in each RVQ codebook.
pub codebook_size: usize,
/// Latent channel count produced by the encoder and quantized by the RVQ.
pub embed_dim: usize,
/// Number of residual quantization stages (codes per latent frame).
pub num_rvq_stages: usize,
/// NOTE(review): not read by `ConvEncoder`/`ConvDecoder` in this module;
/// `CausalConv1d` exists but is not wired in.
pub causal: bool,
}
impl Default for NeuralCodecConfig {
fn default() -> Self {
Self {
input_channels: 1,
hidden_channels: 128,
num_residual_blocks: 2,
strides: vec![2, 4, 5, 8], dilations: vec![1, 3, 9],
codebook_size: 1024,
embed_dim: 256,
num_rvq_stages: 8,
causal: false,
}
}
}
/// A dilated 1-D convolution whose outputs depend only on current and past
/// time steps.
///
/// The wrapped [`Conv1d`] zero-pads `dilation * (kernel_size - 1)` samples on
/// *both* ends of the time axis; [`CausalConv1d::forward`] then discards that
/// many trailing outputs (the ones that would look into the future). With the
/// default stride of 1 the output length equals the input length.
///
/// NOTE(review): `NeuralCodecConfig::causal` is not consulted in this module,
/// so the encoder/decoder never construct this layer.
pub struct CausalConv1d {
    conv: Conv1d,
    /// Number of trailing time steps trimmed in `forward`.
    padding: usize,
}
impl CausalConv1d {
    /// Creates the layer with its weights registered directly under `vb`.
    ///
    /// `kernel_size` must be at least 1.
    ///
    /// # Errors
    /// Propagates any error from creating the convolution weights.
    pub fn new(
        in_channels: usize,
        out_channels: usize,
        kernel_size: usize,
        dilation: usize,
        vb: VarBuilder,
    ) -> CandleResult<Self> {
        // How far back the receptive field reaches beyond the current sample.
        let lookback = (kernel_size - 1) * dilation;
        let conv = conv1d(
            in_channels,
            out_channels,
            kernel_size,
            Conv1dConfig {
                padding: lookback,
                dilation,
                ..Default::default()
            },
            vb,
        )?;
        Ok(Self {
            conv,
            padding: lookback,
        })
    }

    /// Runs the convolution and drops the look-ahead tail along dim 2 (time).
    pub fn forward(&self, x: &Tensor) -> CandleResult<Tensor> {
        let full = self.conv.forward(x)?;
        match self.padding {
            0 => Ok(full),
            trim => {
                let kept = full.dim(2)? - trim;
                full.narrow(2, 0, kept)
            }
        }
    }
}
/// Dilated residual unit computing `elu(x + conv2(elu(conv1(x))))`.
///
/// `conv1` is dilated by `dilation`; `conv2` is undilated. Each is padded by
/// half its receptive field, which keeps the time length unchanged for an odd
/// `kernel_size` so the residual addition lines up.
///
/// NOTE(review): an even `kernel_size` shortens the sequence and makes the
/// addition fail; every caller in this module passes 3.
pub struct ResidualUnit {
    conv1: Conv1d,
    conv2: Conv1d,
}
impl ResidualUnit {
    /// Builds both convolutions (`channels -> channels`) under `vb.pp("conv1")`
    /// and `vb.pp("conv2")`.
    ///
    /// # Errors
    /// Propagates any error from creating the convolution weights.
    pub fn new(
        channels: usize,
        dilation: usize,
        kernel_size: usize,
        vb: VarBuilder,
    ) -> CandleResult<Self> {
        let dilated_cfg = Conv1dConfig {
            padding: dilation * (kernel_size - 1) / 2,
            dilation,
            ..Default::default()
        };
        let plain_cfg = Conv1dConfig {
            padding: (kernel_size - 1) / 2,
            ..Default::default()
        };
        // Keep creation order conv1 -> conv2 so weight initialisation is unchanged.
        let conv1 = conv1d(channels, channels, kernel_size, dilated_cfg, vb.pp("conv1"))?;
        let conv2 = conv1d(channels, channels, kernel_size, plain_cfg, vb.pp("conv2"))?;
        Ok(Self { conv1, conv2 })
    }

    /// Applies the residual branch, adds the input back, and activates with ELU(1.0).
    pub fn forward(&self, x: &Tensor) -> CandleResult<Tensor> {
        let hidden = self.conv1.forward(x)?.elu(1.0)?;
        let branch = self.conv2.forward(&hidden)?;
        (branch + x)?.elu(1.0)
    }
}
/// Convolutional encoder mapping a waveform to a down-sampled latent.
///
/// Input and output follow candle's `Conv1d` layout `(batch, channels, time)`:
/// `input_channels` channels in, `embed_dim` channels out, with the time axis
/// reduced by the product of `config.strides`.
pub struct ConvEncoder {
    /// 7-tap, length-preserving stem conv: `input_channels -> hidden_channels`.
    init_conv: Conv1d,
    /// One strided down-sampling conv per entry of `config.strides`.
    down_layers: Vec<Conv1d>,
    /// Residual units applied after each down-sampling conv (one per dilation).
    residual_blocks: Vec<Vec<ResidualUnit>>,
    /// 3-tap, length-preserving projection to `embed_dim` channels.
    final_conv: Conv1d,
}
impl ConvEncoder {
    /// Builds the encoder from `config`, registering weights under `vb` as
    /// `init`, `down_{i}`, `res_{i}_{j}` and `final`.
    ///
    /// Stage `i` uses a kernel of `2 * stride` and outputs
    /// `hidden_channels * 2^(i + 1)` channels. For every stride >= 2, a stage
    /// maps a length that is a multiple of `stride` to exactly `len / stride`
    /// frames, including odd strides.
    ///
    /// # Errors
    /// Propagates any error from creating the convolution weights.
    pub fn new(config: &NeuralCodecConfig, vb: VarBuilder) -> CandleResult<Self> {
        let init_conv = conv1d(
            config.input_channels,
            config.hidden_channels,
            7,
            Conv1dConfig {
                padding: 3,
                ..Default::default()
            },
            vb.pp("init"),
        )?;
        let mut current_channels = config.hidden_channels;
        let mut down_layers = Vec::with_capacity(config.strides.len());
        let mut residual_blocks = Vec::with_capacity(config.strides.len());
        for (i, &stride) in config.strides.iter().enumerate() {
            let out_channels = config.hidden_channels * 2usize.pow((i + 1) as u32);
            // With kernel 2*stride, padding ceil(stride/2) yields len/stride frames
            // for both even and odd strides. Plain `stride / 2` loses a frame for
            // odd strides, and can even produce a negative length for short inputs.
            let down = conv1d(
                current_channels,
                out_channels,
                2 * stride,
                Conv1dConfig {
                    stride,
                    padding: stride.div_ceil(2),
                    ..Default::default()
                },
                vb.pp(format!("down_{}", i)),
            )?;
            down_layers.push(down);
            let blocks = config
                .dilations
                .iter()
                .enumerate()
                .map(|(j, &dilation)| {
                    ResidualUnit::new(out_channels, dilation, 3, vb.pp(format!("res_{}_{}", i, j)))
                })
                .collect::<CandleResult<Vec<_>>>()?;
            residual_blocks.push(blocks);
            current_channels = out_channels;
        }
        let final_conv = conv1d(
            current_channels,
            config.embed_dim,
            3,
            Conv1dConfig {
                padding: 1,
                ..Default::default()
            },
            vb.pp("final"),
        )?;
        Ok(Self {
            init_conv,
            down_layers,
            residual_blocks,
            final_conv,
        })
    }

    /// Stem conv + ELU, then for each stage: down-sample, ELU, residual units;
    /// finally the linear projection to `embed_dim` channels.
    pub fn forward(&self, x: &Tensor) -> CandleResult<Tensor> {
        let mut out = self.init_conv.forward(x)?.elu(1.0)?;
        for (down_layer, res_blocks) in self.down_layers.iter().zip(self.residual_blocks.iter()) {
            out = down_layer.forward(&out)?.elu(1.0)?;
            for res_block in res_blocks {
                out = res_block.forward(&out)?;
            }
        }
        self.final_conv.forward(&out)
    }
}
/// Convolutional decoder mirroring [`ConvEncoder`]: maps an `embed_dim`-channel
/// latent back to an `input_channels`-channel waveform, up-sampling the time
/// axis by the product of `config.strides`.
pub struct ConvDecoder {
    /// 3-tap projection: `embed_dim -> hidden_channels * 2^strides.len()`.
    init_conv: Conv1d,
    /// One transposed conv per stride, applied in reverse stride order.
    up_layers: Vec<ConvTranspose1d>,
    /// Residual units applied *before* each up-sampling layer.
    residual_blocks: Vec<Vec<ResidualUnit>>,
    /// 7-tap output conv: `hidden_channels -> input_channels`.
    final_conv: Conv1d,
}
impl ConvDecoder {
    /// Builds the decoder from `config`, registering weights under `vb` as
    /// `init`, `res_{i}_{j}`, `up_{i}` and `final`. Here `i` is the index of the
    /// encoder stage being undone.
    ///
    /// The stage that undoes encoder stage `i` maps
    /// `hidden_channels * 2^(i + 1) -> hidden_channels * 2^i` channels, so the
    /// channel schedule exactly mirrors the encoder and ends at `hidden_channels`.
    /// For every stride >= 2, a stage multiplies the frame count by exactly
    /// `stride`.
    ///
    /// # Errors
    /// Propagates any error from creating the convolution weights.
    pub fn new(config: &NeuralCodecConfig, vb: VarBuilder) -> CandleResult<Self> {
        let last_layer_channels = config.hidden_channels * 2usize.pow(config.strides.len() as u32);
        let init_conv = conv1d(
            config.embed_dim,
            last_layer_channels,
            3,
            Conv1dConfig {
                padding: 1,
                ..Default::default()
            },
            vb.pp("init"),
        )?;
        let mut up_layers = Vec::with_capacity(config.strides.len());
        let mut residual_blocks = Vec::with_capacity(config.strides.len());
        let mut current_channels = last_layer_channels;
        // `enumerate().rev()` already yields encoder indices last-to-first, so `i`
        // is used directly. Reversing it a second time would send the first up
        // layer straight to `hidden_channels` and grow the channels back up.
        for (i, &stride) in config.strides.iter().enumerate().rev() {
            let out_channels = config.hidden_channels * 2usize.pow(i as u32);
            let blocks = config
                .dilations
                .iter()
                .enumerate()
                .map(|(j, &dilation)| {
                    ResidualUnit::new(
                        current_channels,
                        dilation,
                        3,
                        vb.pp(format!("res_{}_{}", i, j)),
                    )
                })
                .collect::<CandleResult<Vec<_>>>()?;
            residual_blocks.push(blocks);
            // Output length = (L-1)*s - 2p + (2s-1) + op + 1 = L*s when
            // p = ceil(s/2) and op = s % 2 (odd strides would otherwise give L*s + 1).
            let up = conv_transpose1d(
                current_channels,
                out_channels,
                2 * stride,
                ConvTranspose1dConfig {
                    stride,
                    padding: stride.div_ceil(2),
                    output_padding: stride % 2,
                    ..Default::default()
                },
                vb.pp(format!("up_{}", i)),
            )?;
            up_layers.push(up);
            current_channels = out_channels;
        }
        let final_conv = conv1d(
            current_channels,
            config.input_channels,
            7,
            Conv1dConfig {
                padding: 3,
                ..Default::default()
            },
            vb.pp("final"),
        )?;
        Ok(Self {
            init_conv,
            up_layers,
            residual_blocks,
            final_conv,
        })
    }

    /// Projection + ELU, then for each stage: residual units, up-sample, ELU;
    /// finally the linear output conv.
    pub fn forward(&self, x: &Tensor) -> CandleResult<Tensor> {
        let mut out = self.init_conv.forward(x)?.elu(1.0)?;
        for (res_blocks, up_layer) in self.residual_blocks.iter().zip(self.up_layers.iter()) {
            for res_block in res_blocks {
                out = res_block.forward(&out)?;
            }
            out = up_layer.forward(&out)?.elu(1.0)?;
        }
        self.final_conv.forward(&out)
    }
}
/// End-to-end neural codec: [`ConvEncoder`] → per-frame residual vector
/// quantization → [`ConvDecoder`].
pub struct NeuralCodec {
    config: NeuralCodecConfig,
    encoder: ConvEncoder,
    decoder: ConvDecoder,
    rvq: ResidualVQ,
    /// Device for the input/latent tensors (taken from the `VarBuilder`).
    device: Device,
}
impl NeuralCodec {
    /// Builds the encoder (`encoder.*`), the decoder (`decoder.*`) and an RVQ with
    /// `num_rvq_stages` stages over `embed_dim`-dimensional vectors.
    ///
    /// The RVQ hyper-parameters are fixed here: commitment beta 0.25, EMA decay
    /// 0.99, epsilon 1e-5, EMA updates enabled.
    ///
    /// # Errors
    /// Propagates weight-creation errors from the encoder or decoder.
    pub fn new(config: NeuralCodecConfig, vb: VarBuilder) -> CandleResult<Self> {
        let device = vb.device().clone();
        let encoder = ConvEncoder::new(&config, vb.pp("encoder"))?;
        let decoder = ConvDecoder::new(&config, vb.pp("decoder"))?;
        let vq_config = VQConfig {
            codebook_size: config.codebook_size,
            embed_dim: config.embed_dim,
            commitment_beta: 0.25,
            ema_decay: 0.99,
            epsilon: 1e-5,
            use_ema: true,
        };
        let rvq = ResidualVQ::new(config.num_rvq_stages, vq_config);
        Ok(Self {
            config,
            encoder,
            decoder,
            rvq,
            device,
        })
    }

    /// Encodes a single-channel signal into one RVQ code sequence per latent frame.
    ///
    /// # Errors
    /// - `signal` is shorter than one frame (`compression_ratio()` samples),
    ///   including the empty signal;
    /// - tensor creation or the encoder forward pass fails;
    /// - the RVQ fails to encode a frame.
    pub fn encode(&self, signal: &[f32]) -> TokenizerResult<Vec<Vec<usize>>> {
        // A single latent frame needs `product(strides)` samples. Shorter input
        // cannot produce a frame, so reject it here instead of inside the convs.
        let min_len = self.config.strides.iter().product::<usize>().max(1);
        if signal.len() < min_len {
            return Err(TokenizerError::encoding(
                "neural_codec",
                format!(
                    "Signal of {} samples is shorter than one frame ({} samples)",
                    signal.len(),
                    min_len
                ),
            ));
        }
        let tensor = Tensor::from_slice(signal, (1, 1, signal.len()), &self.device)
            .map_err(|e| TokenizerError::encoding("neural_codec", e.to_string()))?;
        let latent = self
            .encoder
            .forward(&tensor)
            .map_err(|e| TokenizerError::encoding("neural_codec_encoder", e.to_string()))?;
        // (1, embed_dim, frames) -> (embed_dim, frames) -> (frames, embed_dim), so
        // each row is the latent vector of one frame.
        let frames = latent
            .squeeze(0)
            .map_err(|e| TokenizerError::encoding("neural_codec_squeeze", e.to_string()))?
            .t()
            .map_err(|e| TokenizerError::encoding("neural_codec_latent", e.to_string()))?
            .to_vec2::<f32>()
            .map_err(|e| TokenizerError::encoding("neural_codec_latent", e.to_string()))?;
        let mut codes = Vec::with_capacity(frames.len());
        for frame in frames {
            let (code_seq, _) = self.rvq.encode(&Array1::from_vec(frame))?;
            codes.push(code_seq);
        }
        Ok(codes)
    }

    /// Decodes per-frame RVQ codes back into a waveform.
    ///
    /// Frames whose decoded vector is shorter than `embed_dim` leave the remaining
    /// channels at 0. Components beyond `embed_dim` are ignored. The decoder output
    /// is squeezed on dims 0 and 1, so only a single output channel
    /// (`input_channels == 1`) yields a flat `Vec`.
    ///
    /// # Errors
    /// - `codes` is empty;
    /// - the RVQ fails to decode a frame;
    /// - tensor creation or the decoder forward pass fails.
    pub fn decode(&self, codes: &[Vec<usize>]) -> TokenizerResult<Vec<f32>> {
        if codes.is_empty() {
            return Err(TokenizerError::decoding(
                "neural_codec",
                "Empty code sequence".to_string(),
            ));
        }
        let time_steps = codes.len();
        let embed_dim = self.config.embed_dim;
        // Channel-major buffer matching the (1, embed_dim, time_steps) tensor layout.
        let mut latent_data = vec![vec![0.0f32; time_steps]; embed_dim];
        for (t, code_seq) in codes.iter().enumerate() {
            let vector = self.rvq.decode(code_seq)?;
            for (row, &val) in latent_data.iter_mut().zip(vector.iter()) {
                row[t] = val;
            }
        }
        let flat_data: Vec<f32> = latent_data.iter().flatten().copied().collect();
        let latent = Tensor::from_slice(&flat_data, (1, embed_dim, time_steps), &self.device)
            .map_err(|e| TokenizerError::decoding("neural_codec_latent", e.to_string()))?;
        let output = self
            .decoder
            .forward(&latent)
            .map_err(|e| TokenizerError::decoding("neural_codec_decoder", e.to_string()))?;
        let signal = output
            .squeeze(0)
            .map_err(|e| TokenizerError::decoding("neural_codec_squeeze1", e.to_string()))?
            .squeeze(0)
            .map_err(|e| TokenizerError::decoding("neural_codec_squeeze2", e.to_string()))?
            .to_vec1::<f32>()
            .map_err(|e| TokenizerError::decoding("neural_codec_output", e.to_string()))?;
        Ok(signal)
    }

    /// Samples per latent frame: the product of `config.strides` (1 if empty).
    pub fn compression_ratio(&self) -> f32 {
        let total_stride: usize = self.config.strides.iter().product();
        total_stride as f32
    }

    /// Bits per second at `sample_rate`:
    /// `(sample_rate / compression_ratio) * num_rvq_stages * log2(codebook_size)`.
    pub fn bitrate(&self, sample_rate: f32) -> f32 {
        let frames_per_second = sample_rate / self.compression_ratio();
        let bits_per_frame =
            self.config.num_rvq_stages as f32 * (self.config.codebook_size as f32).log2();
        frames_per_second * bits_per_frame
    }
}
// Unit tests for configuration defaults, compression ratio, bitrate, and an
// (ignored) encode/decode smoke test.
#[cfg(test)]
mod tests {
use super::*;
use candle_core::DType;
use candle_nn::VarMap;
/// The default config exposes the baseline values.
#[test]
fn test_neural_codec_config_default() {
let config = NeuralCodecConfig::default();
assert_eq!(config.input_channels, 1);
assert_eq!(config.hidden_channels, 128);
assert_eq!(config.num_rvq_stages, 8);
assert!(!config.causal);
}
/// `compression_ratio` equals the product of the strides (2 * 4 * 5 * 8 = 320).
/// A tiny config keeps weight construction cheap.
#[test]
fn test_compression_ratio() {
let config = NeuralCodecConfig {
input_channels: 1,
hidden_channels: 8, num_residual_blocks: 1, strides: vec![2, 4, 5, 8], dilations: vec![1], codebook_size: 64, embed_dim: 8, num_rvq_stages: 2, causal: false,
};
let varmap = VarMap::new();
let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
let codec = NeuralCodec::new(config.clone(), vb).unwrap();
let ratio = codec.compression_ratio();
let expected: usize = config.strides.iter().product();
assert_eq!(ratio, expected as f32);
}
/// 16 kHz / 320-sample hop = 50 frames/s; 8 stages * log2(1024) = 80 bits per
/// frame; 50 * 80 = 4000 bit/s.
#[test]
fn test_bitrate_calculation() {
let config = NeuralCodecConfig {
input_channels: 1,
hidden_channels: 8, num_residual_blocks: 1, strides: vec![2, 4, 5, 8], dilations: vec![1], codebook_size: 1024, embed_dim: 8, num_rvq_stages: 8, causal: false,
};
let varmap = VarMap::new();
let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
let codec = NeuralCodec::new(config, vb).unwrap();
let bitrate = codec.bitrate(16000.0);
assert!((bitrate - 4000.0).abs() < 1.0);
}
/// Round-trip smoke test with the full default config (ignored by default).
///
/// NOTE(review): the default hop is 2 * 4 * 5 * 8 = 320 samples and 1024 is not
/// a multiple of it. An exact length match between `signal` and `reconstructed`
/// should therefore not be expected — confirm the intended signal length.
#[test]
#[ignore] fn test_encode_decode_roundtrip() {
let config = NeuralCodecConfig::default();
let varmap = VarMap::new();
let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
let codec = NeuralCodec::new(config, vb).unwrap();
let signal: Vec<f32> = (0..1024).map(|i| (i as f32 * 0.01).sin()).collect();
let codes = codec.encode(&signal).unwrap();
let reconstructed = codec.decode(&codes).unwrap();
assert_eq!(reconstructed.len(), signal.len());
}
}