any-tts 0.1.1

A Rust TTS library with Candle backends and runtime adapters for modern open TTS models
Documentation
use candle_core::{Device, IndexOp, Result, Tensor};
use candle_nn::{Linear, Module, VarBuilder};

use crate::layers::conv::{Conv1d, ConvTranspose1d};

use super::config::OmniVoiceAudioTokenizerConfig;

/// Snake activation: `x + sin²(alpha · x) / (alpha + 1e-9)`.
///
/// `alpha` is broadcast against `x`, so it only needs a shape that is
/// broadcast-compatible with the input.
fn snake1d(x: &Tensor, alpha: &Tensor) -> Result<Tensor> {
    let periodic = alpha.broadcast_mul(x)?.sin()?.sqr()?;
    // Offset alpha by a tiny epsilon before inverting to guard against division by zero.
    let reciprocal = (alpha + 1e-9)?.recip()?;
    let correction = periodic.broadcast_mul(&reciprocal)?;
    x.add(&correction)
}

/// Centre-crops `input` along dimension 2 so its length matches `target`.
///
/// When `input` is not longer than `target` it is returned unchanged (as a clone);
/// otherwise half of the excess (rounded down) is skipped at the front and exactly
/// `target`'s length is kept.
fn trim_residual_input(input: &Tensor, target: &Tensor) -> Result<Tensor> {
    let (available, wanted) = (input.dim(2)?, target.dim(2)?);
    match available.checked_sub(wanted) {
        Some(excess) if excess > 0 => input.narrow(2, excess / 2, wanted),
        // Equal length, or input shorter than target: nothing to trim.
        _ => Ok(input.clone()),
    }
}

/// A single codebook plus the linear projection applied to its looked-up entries.
struct OmniVoiceVectorQuantizer {
    /// Embedding table, loaded with shape `(codebook_size, codebook_dim)`.
    codebook: Tensor,
    /// Linear layer mapping `codebook_dim` to the acoustic decoder hidden size.
    project_out: Linear,
}

impl OmniVoiceVectorQuantizer {
    /// Loads the `codebook.embed` table and the `project_out` layer from `vb`.
    fn load(config: &OmniVoiceAudioTokenizerConfig, vb: VarBuilder) -> Result<Self> {
        let entry_dim = config.codebook_dim;
        let table_shape = (config.codebook_size, entry_dim);
        let codebook = vb.pp("codebook").get(table_shape, "embed")?;

        let hidden = config.acoustic_model_config.decoder_hidden_size;
        let project_out = candle_nn::linear(entry_dim, hidden, vb.pp("project_out"))?;

        Ok(Self {
            codebook,
            project_out,
        })
    }

    /// Looks up the codebook rows for a rank-2 `(batch, seq_len)` index tensor,
    /// projects them, and returns the result with the last two axes swapped,
    /// i.e. `(batch, projected_dim, seq_len)`.
    fn decode(&self, indices: &Tensor) -> Result<Tensor> {
        let (batch, seq_len) = indices.dims2()?;

        // Gather over a flat index list, then restore the (batch, seq) layout.
        let flat_indices = indices.flatten_all()?;
        let rows = self.codebook.index_select(&flat_indices, 0)?;
        let width = self.codebook.dim(1)?;
        let embedded = rows.reshape((batch, seq_len, width))?;

        self.project_out.forward(&embedded)?.transpose(1, 2)
    }
}

struct OmniVoiceResidualVectorQuantizer {
    quantizers: Vec<OmniVoiceVectorQuantizer>,
}

impl OmniVoiceResidualVectorQuantizer {
    fn load(
        config: &OmniVoiceAudioTokenizerConfig,
        num_quantizers: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let mut quantizers = Vec::with_capacity(num_quantizers);
        for index in 0..num_quantizers {
            quantizers.push(OmniVoiceVectorQuantizer::load(
                config,
                vb.pp(format!("quantizers.{}", index)),
            )?);
        }
        Ok(Self { quantizers })
    }

    fn decode(&self, codes: &Tensor) -> Result<Tensor> {
        let mut summed: Option<Tensor> = None;
        for (index, quantizer) in self.quantizers.iter().enumerate() {
            let indices = codes.i((.., index, ..))?;
            let decoded = quantizer.decode(&indices)?;
            summed = Some(match summed {
                Some(acc) => acc.add(&decoded)?,
                None => decoded,
            });
        }
        summed.ok_or_else(|| {
            candle_core::Error::Msg("OmniVoice audio decoder has no quantizers".to_string())
        })
    }
}

/// Residual unit: snake → dilated kernel-7 conv → snake → kernel-1 conv, plus a skip connection.
struct DacResidualUnit {
    /// Snake parameter applied before `conv1`; loaded with shape `(1, channels, 1)`.
    snake1_alpha: Tensor,
    /// Dilated convolution with kernel size 7.
    conv1: Conv1d,
    /// Snake parameter applied before `conv2`; loaded with shape `(1, channels, 1)`.
    snake2_alpha: Tensor,
    /// Convolution with kernel size 1.
    conv2: Conv1d,
}

impl DacResidualUnit {
    /// Loads the unit's weights from the `snake1`, `conv1`, `snake2`, and `conv2` sub-paths.
    ///
    /// The trailing `Conv1d::load` arguments are passed positionally; they are presumably
    /// `(kernel, stride, padding, dilation, groups, bias)` — confirm against
    /// `crate::layers::conv::Conv1d`.
    fn load(channels: usize, dilation: usize, vb: VarBuilder) -> Result<Self> {
        let alpha_shape: (usize, usize, usize) = (1, channels, 1);

        let snake1_alpha = vb.pp("snake1").get(alpha_shape, "alpha")?;
        let conv1 = Conv1d::load(
            channels,
            channels,
            7,
            1,
            3 * dilation,
            dilation,
            1,
            true,
            vb.pp("conv1"),
        )?;

        let snake2_alpha = vb.pp("snake2").get(alpha_shape, "alpha")?;
        let conv2 = Conv1d::load(channels, channels, 1, 1, 0, 1, 1, true, vb.pp("conv2"))?;

        Ok(Self {
            snake1_alpha,
            conv1,
            snake2_alpha,
            conv2,
        })
    }

    /// Runs the two activation/convolution stages and adds the (possibly trimmed) input back.
    fn forward(&self, hidden: &Tensor) -> Result<Tensor> {
        let activated = snake1d(hidden, &self.snake1_alpha)?;
        let branch = self.conv1.forward(&activated)?;

        let activated = snake1d(&branch, &self.snake2_alpha)?;
        let branch = self.conv2.forward(&activated)?;

        // The skip path is cropped to the branch length before the residual sum.
        let skip = trim_residual_input(hidden, &branch)?;
        skip.add(&branch)
    }
}

/// One decoder stage: snake activation, transposed convolution, then three residual units.
struct DacDecoderBlock {
    /// Snake parameter applied to the block input; loaded with shape `(1, input_dim, 1)`.
    snake1_alpha: Tensor,
    /// Transposed convolution mapping `input_dim` channels to `output_dim` channels.
    conv_t1: ConvTranspose1d,
    /// Residual unit with dilation 1.
    res_unit1: DacResidualUnit,
    /// Residual unit with dilation 3.
    res_unit2: DacResidualUnit,
    /// Residual unit with dilation 9.
    res_unit3: DacResidualUnit,
}

impl DacDecoderBlock {
    /// Loads the block at position `stride_index` in the decoder.
    ///
    /// Channel widths are derived from `decoder_hidden_size`: the input width is
    /// `decoder_hidden_size / 2^stride_index` and the output width is half of that
    /// (integer division).
    fn load(
        config: &OmniVoiceAudioTokenizerConfig,
        stride: usize,
        stride_index: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let base_width = config.acoustic_model_config.decoder_hidden_size;
        let width_at = |level: usize| base_width / (1usize << level);
        let input_dim = width_at(stride_index);
        let output_dim = width_at(stride_index + 1);

        let alpha_shape: (usize, usize, usize) = (1, input_dim, 1);
        let snake1_alpha = vb.pp("snake1").get(alpha_shape, "alpha")?;

        // Positional arguments after the channel counts are presumably
        // (kernel, stride, padding, output_padding, groups, bias) — confirm against
        // `crate::layers::conv::ConvTranspose1d`.
        let conv_t1 = ConvTranspose1d::load(
            input_dim,
            output_dim,
            2 * stride,
            stride,
            stride.div_ceil(2),
            stride % 2,
            1,
            true,
            vb.pp("conv_t1"),
        )?;

        let residual =
            |dilation: usize, name: &str| DacResidualUnit::load(output_dim, dilation, vb.pp(name));
        let res_unit1 = residual(1, "res_unit1")?;
        let res_unit2 = residual(3, "res_unit2")?;
        let res_unit3 = residual(9, "res_unit3")?;

        Ok(Self {
            snake1_alpha,
            conv_t1,
            res_unit1,
            res_unit2,
            res_unit3,
        })
    }

    /// Applies the activation and transposed convolution, then the three residual units in order.
    fn forward(&self, hidden: &Tensor) -> Result<Tensor> {
        let activated = snake1d(hidden, &self.snake1_alpha)?;
        let mut state = self.conv_t1.forward(&activated)?;
        for unit in [&self.res_unit1, &self.res_unit2, &self.res_unit3] {
            state = unit.forward(&state)?;
        }
        Ok(state)
    }
}

/// Decoder stack: input convolution, one block per upsampling ratio, final snake and convolution.
struct DacDecoder {
    /// Kernel-7 convolution from `hidden_size` to `decoder_hidden_size` channels.
    conv1: Conv1d,
    /// One decoder block per entry of `upsampling_ratios`, in configuration order.
    blocks: Vec<DacDecoderBlock>,
    /// Snake parameter applied after the last block; loaded with shape `(1, output_dim, 1)`.
    snake1_alpha: Tensor,
    /// Kernel-7 convolution reducing `output_dim` channels to a single channel.
    conv2: Conv1d,
}

impl DacDecoder {
    /// Loads the decoder from the `conv1`, `block.<index>`, `snake1`, and `conv2` sub-paths.
    fn load(config: &OmniVoiceAudioTokenizerConfig, vb: VarBuilder) -> Result<Self> {
        let acoustic = &config.acoustic_model_config;

        let conv1 = Conv1d::load(
            acoustic.hidden_size,
            acoustic.decoder_hidden_size,
            7,
            1,
            3,
            1,
            1,
            true,
            vb.pp("conv1"),
        )?;

        // Collecting into `Result` stops at the first block that fails to load.
        let blocks = acoustic
            .upsampling_ratios
            .iter()
            .copied()
            .enumerate()
            .map(|(index, stride)| {
                DacDecoderBlock::load(config, stride, index, vb.pp(format!("block.{}", index)))
            })
            .collect::<Result<Vec<_>>>()?;

        // Each block halves the channel count, so the final width is hidden / 2^(block count).
        let stage_count = acoustic.upsampling_ratios.len();
        let output_dim = acoustic.decoder_hidden_size / (1usize << stage_count);

        let alpha_shape: (usize, usize, usize) = (1, output_dim, 1);
        let snake1_alpha = vb.pp("snake1").get(alpha_shape, "alpha")?;
        let conv2 = Conv1d::load(output_dim, 1, 7, 1, 3, 1, 1, true, vb.pp("conv2"))?;

        Ok(Self {
            conv1,
            blocks,
            snake1_alpha,
            conv2,
        })
    }

    /// Runs the full decoder and clamps the output to `[-1, 1]`.
    fn forward(&self, hidden: &Tensor) -> Result<Tensor> {
        let stem = self.conv1.forward(hidden)?;
        let upsampled = self
            .blocks
            .iter()
            .try_fold(stem, |state, block| block.forward(&state))?;
        let activated = snake1d(&upsampled, &self.snake1_alpha)?;
        let waveform = self.conv2.forward(&activated)?;
        waveform.clamp(-1f32, 1f32)
    }
}

/// Decoder half of the OmniVoice audio tokenizer: turns codebook indices into audio samples.
pub struct OmniVoiceAudioTokenizerDecoder {
    /// Residual quantizer that converts code indices back into summed embeddings.
    quantizer: OmniVoiceResidualVectorQuantizer,
    /// Linear layer from `decoder_hidden_size` to `hidden_size`, applied per time step.
    fc2: Linear,
    /// Convolutional decoder producing the clamped output signal.
    acoustic_decoder: DacDecoder,
    /// Sample rate copied from the configuration at load time.
    sample_rate: u32,
}

impl OmniVoiceAudioTokenizerDecoder {
    /// Loads the decoder weights from the `quantizer`, `fc2`, and `acoustic_decoder`
    /// sub-paths of `vb`.
    ///
    /// `num_quantizers` is the number of codebooks to load. `_device` is accepted but
    /// not used by this function.
    ///
    /// # Errors
    ///
    /// Returns an error if any tensor cannot be fetched from `vb`.
    pub fn load(
        config: &OmniVoiceAudioTokenizerConfig,
        num_quantizers: usize,
        vb: VarBuilder,
        _device: &Device,
    ) -> Result<Self> {
        let acoustic = &config.acoustic_model_config;

        let quantizer =
            OmniVoiceResidualVectorQuantizer::load(config, num_quantizers, vb.pp("quantizer"))?;
        let fc2 = candle_nn::linear(
            acoustic.decoder_hidden_size,
            acoustic.hidden_size,
            vb.pp("fc2"),
        )?;
        let acoustic_decoder = DacDecoder::load(config, vb.pp("acoustic_decoder"))?;

        Ok(Self {
            quantizer,
            fc2,
            acoustic_decoder,
            sample_rate: config.sample_rate,
        })
    }

    /// Decodes a rank-3 tensor of code indices, laid out as `(batch, codebook, time)`,
    /// into an output tensor whose values are clamped to `[-1, 1]`.
    ///
    /// # Errors
    ///
    /// Returns an error if no quantizers were loaded or if any tensor operation fails
    /// (for example, when `codes` has fewer codebooks than loaded quantizers).
    pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
        let summed = self.quantizer.decode(codes)?;

        // `fc2` acts on the last axis, so move channels last, project, then move them back.
        let channels_last = summed.transpose(1, 2)?;
        let projected = self.fc2.forward(&channels_last)?;
        let channels_first = projected.transpose(1, 2)?;

        self.acoustic_decoder.forward(&channels_first)
    }

    /// Returns the sample rate recorded from the configuration.
    pub fn sample_rate(&self) -> u32 {
        self.sample_rate
    }
}