voxcpm-rs 0.3.0

Pure-Rust inference for VoxCPM2 on top of the Burn framework (Vulkan + CPU).
Documentation
//! Top-level AudioVAE v2 wrapper.

use crate::audiovae::decoder::CausalDecoder;
use crate::audiovae::encoder::CausalEncoder;
use crate::config::AudioVaeConfig;
use burn::module::Ignored;
use burn::prelude::*;

#[derive(Module, Debug)]
pub struct AudioVae<B: Backend> {
    pub encoder: CausalEncoder<B>,
    pub decoder: CausalDecoder<B>,
    pub config: Ignored<AudioVaeConfig>,
}

impl<B: Backend> AudioVae<B> {
    pub fn new(config: AudioVaeConfig, device: &B::Device) -> Self {
        let encoder = CausalEncoder::new(
            config.encoder_dim,
            config.latent_dim,
            &config.encoder_rates,
            config.depthwise,
            device,
        );
        let decoder = CausalDecoder::new(
            config.latent_dim,
            config.decoder_dim,
            &config.decoder_rates,
            config.depthwise,
            config.use_noise_block,
            config.sr_bin_boundaries.clone(),
            &config.cond_type,
            device,
        );
        Self {
            encoder,
            decoder,
            config: Ignored(config),
        }
    }

    /// Encode raw waveform `[B, 1, T]` → latent `[B, C, T']` (µ only).
    pub fn encode(&self, audio: Tensor<B, 3>) -> Tensor<B, 3> {
        let hop: usize = self.config.0.encoder_rates.iter().product();
        let t = audio.dims()[2];
        let pad = (t + hop - 1) / hop * hop - t;
        let audio = if pad > 0 {
            audio.pad((0, pad, 0, 0), 0.0)
        } else {
            audio
        };
        self.encoder.forward_mu(audio)
    }

    /// Decode latent `[B, C, T']` → waveform `[B, 1, T]`.
    pub fn decode(&self, z: Tensor<B, 3>) -> Tensor<B, 3> {
        let target_sr = self.config.0.out_sample_rate as i32;
        self.decoder.forward(z, target_sr)
    }

    pub fn sample_rate(&self) -> usize {
        self.config.0.sample_rate
    }
    pub fn out_sample_rate(&self) -> usize {
        self.config.0.out_sample_rate
    }
}