use candle_core::{DType, IndexOp, Result, Tensor};
use candle_nn::VarBuilder;
use crate::audio::encoder::{
quantizer::EncoderSplitRVQ,
seanet::{Downsample, EncoderTransformer, SeaNetEncoder},
};
/// Audio-to-codes encoder built from four sub-modules.
///
/// `encode` runs the sub-modules in the order `seanet` -> `transformer` ->
/// `downsample` -> `quantizer`, then keeps at most `valid_num_quantizers`
/// entries along dim 1 of the resulting codes tensor.
#[derive(Debug, Clone)]
pub struct TokenizerV2Encoder {
/// First stage; applied directly to the input audio tensor.
seanet: SeaNetEncoder,
/// Second stage; applied to the `seanet` output with dims 1 and 2 swapped.
transformer: EncoderTransformer,
/// Third stage; applied after the transformer output is transposed back.
downsample: Downsample,
/// Final stage; turns the downsampled features into codes.
quantizer: EncoderSplitRVQ,
/// Upper bound on the size of dim 1 of the codes returned by `encode`.
valid_num_quantizers: usize,
}
impl TokenizerV2Encoder {
pub fn new(
valid_num_quantizers: usize,
device: &candle_core::Device,
dtype: DType,
vb: VarBuilder,
) -> Result<Self> {
let seanet = SeaNetEncoder::new(vb.pp("encoder"))?;
let transformer =
EncoderTransformer::new(512, 8, 2048, 8, device, dtype, vb.pp("encoder_transformer"))?;
let downsample = Downsample::new(512, 2, vb.pp("downsample"))?;
let quantizer = EncoderSplitRVQ::new(
256,
512,
512,
1,
31, 2048,
vb.pp("quantizer"),
)?;
Ok(Self {
seanet,
transformer,
downsample,
quantizer,
valid_num_quantizers,
})
}
pub fn for_v2(valid_num_quantizers: usize, vb: VarBuilder) -> Result<Self> {
let device = vb.device().clone();
let dtype = vb.dtype();
Self::new(valid_num_quantizers, &device, dtype, vb)
}
pub fn valid_num_quantizers(&self) -> usize {
self.valid_num_quantizers
}
pub fn encode(&mut self, audio: &Tensor) -> Result<Tensor> {
let audio = if audio.dims().len() == 2 {
audio.unsqueeze(1)?
} else {
audio.clone()
};
let h = self.seanet.forward(&audio)?;
let h = h.transpose(1, 2)?;
let h = self.transformer.forward(&h)?;
let h = h.transpose(1, 2)?;
let h = self.downsample.forward(&h)?;
let codes = self.quantizer.encode(&h)?;
let n_q = codes.dim(1)?;
if n_q > self.valid_num_quantizers {
codes.narrow(1, 0, self.valid_num_quantizers)
} else {
Ok(codes)
}
}
pub fn encode_with_mask(
&mut self,
audio: &Tensor,
padding_mask: &Tensor,
downsample_rate: usize,
) -> Result<Vec<Tensor>> {
let codes = self.encode(audio)?;
let batch_size = audio.dim(0)?;
let mut result = Vec::with_capacity(batch_size);
for b in 0..batch_size {
let mask = padding_mask.i(b)?;
let code = codes.i(b)?;
let mask_f64 = mask.to_dtype(DType::F64)?;
let valid_samples = mask_f64.sum_all()?.to_scalar::<f64>()? as usize;
let valid_frames = valid_samples / downsample_rate;
let seq_len = code.dim(1)?;
let trimmed = if valid_frames < seq_len && valid_frames > 0 {
code.narrow(1, 0, valid_frames)?
} else {
code.clone()
};
result.push(trimmed.transpose(0, 1)?);
}
Ok(result)
}
pub fn reset_state(&mut self) {
}
}
/// Container for a list of per-item audio code tensors.
#[derive(Debug, Clone)]
pub struct TokenizerV2EncoderOutput {
/// One codes tensor per item; the tensors may differ in shape.
pub audio_codes: Vec<Tensor>,
}
impl TokenizerV2EncoderOutput {
pub fn new(audio_codes: Vec<Tensor>) -> Self {
Self { audio_codes }
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::{DType, Device};

    /// The output wrapper must hand back the tensors it was given, in order
    /// and with their shapes untouched.
    #[test]
    fn test_encoder_output_shape() {
        let cpu = Device::Cpu;
        let expected_shapes: [(usize, usize); 2] = [(25, 16), (30, 16)];

        let codes: Vec<Tensor> = expected_shapes
            .iter()
            .map(|&shape| Tensor::zeros(shape, DType::I64, &cpu).unwrap())
            .collect();
        let output = TokenizerV2EncoderOutput::new(codes);

        assert_eq!(output.audio_codes.len(), 2);
        for (tensor, &(rows, cols)) in output.audio_codes.iter().zip(expected_shapes.iter()) {
            assert_eq!(tensor.dims(), &[rows, cols]);
        }
    }
}