pub mod causal_conv;
pub mod config;
pub mod convnext;
pub mod snake_beta;
pub mod transformer;
use candle_core::{Result, Tensor};
use candle_nn::VarBuilder;
use crate::audio::{
decoder::v2::TokenizerV2Decoder,
encoder::v2::{TokenizerV2Encoder, TokenizerV2EncoderOutput},
tokenizer::v2::config::TokenizerV2Config,
};
/// Audio tokenizer (v2) pairing an optional encoder with an always-present decoder.
///
/// Instances built with `TokenizerV2::new` / `TokenizerV2::load` start with `encoder: None`
/// (decode-only). Instances built with `TokenizerV2::new_full`, or upgraded via
/// `TokenizerV2::with_encoder`, have `encoder: Some(..)`.
#[derive(Debug)]
pub struct TokenizerV2 {
    /// Encoder used by `encode` / `encode_with_mask`; `None` for decode-only instances.
    pub encoder: Option<TokenizerV2Encoder>,
    /// Decoder used by `decode` / `chunked_decode`.
    pub decoder: TokenizerV2Decoder,
    /// Configuration the encoder/decoder are built from; also backs the rate/quantizer getters.
    pub config: TokenizerV2Config,
}
impl TokenizerV2 {
pub fn new(config: TokenizerV2Config, use_flash_attn: bool, vb: VarBuilder) -> Result<Self> {
let decoder =
TokenizerV2Decoder::new(&config.decoder_config, use_flash_attn, vb.pp("decoder"))?;
Ok(Self {
encoder: None,
decoder,
config,
})
}
pub fn new_full(
config: TokenizerV2Config,
use_flash_attn: bool,
encoder_vb: VarBuilder,
decoder_vb: VarBuilder,
) -> Result<Self> {
let encoder = TokenizerV2Encoder::for_v2(config.encoder_valid_num_quantizers, encoder_vb)?;
let decoder = TokenizerV2Decoder::new(&config.decoder_config, use_flash_attn, decoder_vb)?;
Ok(Self {
encoder: Some(encoder),
decoder,
config,
})
}
pub fn with_encoder(mut self, encoder_vb: VarBuilder) -> Result<Self> {
let encoder =
TokenizerV2Encoder::for_v2(self.config.encoder_valid_num_quantizers, encoder_vb)?;
self.encoder = Some(encoder);
Ok(self)
}
pub fn load(config: TokenizerV2Config, use_flash_attn: bool, vb: VarBuilder) -> Result<Self> {
Self::new(config, use_flash_attn, vb)
}
pub fn has_encoder(&self) -> bool {
self.encoder.is_some()
}
pub fn config(&self) -> &TokenizerV2Config {
&self.config
}
pub fn input_sample_rate(&self) -> usize {
self.config.input_sample_rate
}
pub fn output_sample_rate(&self) -> usize {
self.config.output_sample_rate
}
pub fn encode_downsample_rate(&self) -> usize {
self.config.encode_downsample_rate
}
pub fn decode_upsample_rate(&self) -> usize {
self.config.decode_upsample_rate
}
pub fn encoder_valid_num_quantizers(&self) -> usize {
self.config.encoder_valid_num_quantizers
}
pub fn encode(&mut self, audio: &Tensor) -> Result<Tensor> {
match &mut self.encoder {
Some(encoder) => encoder.encode(audio),
None => candle_core::bail!(
"Encoder not loaded. Use new_full() or with_encoder() to load the encoder."
),
}
}
pub fn encode_with_mask(
&mut self,
audio: &Tensor,
padding_mask: &Tensor,
) -> Result<TokenizerV2EncoderOutput> {
match &mut self.encoder {
Some(encoder) => {
let codes = encoder.encode_with_mask(
audio,
padding_mask,
self.config.encode_downsample_rate,
)?;
Ok(TokenizerV2EncoderOutput::new(codes))
}
None => candle_core::bail!(
"Encoder not loaded. Use new_full() or with_encoder() to load the encoder."
),
}
}
pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
let codes = codes.transpose(1, 2)?;
self.decoder.forward(&codes)
}
pub fn chunked_decode(
&self,
codes: &Tensor,
chunk_size: usize,
left_context_size: usize,
) -> Result<Tensor> {
let codes = codes.transpose(1, 2)?;
self.decoder
.chunked_decode(&codes, chunk_size, left_context_size)
}
pub fn reset_encoder_state(&mut self) {
if let Some(encoder) = &mut self.encoder {
encoder.reset_state();
}
}
}