any-tts 0.1.1

A Rust TTS library with Candle backends and runtime adapters for modern open TTS models
Documentation
use std::collections::BTreeMap;

use serde::Deserialize;

use crate::error::TtsError;

/// Number of extra ids added to every codebook size by
/// `MultimodalAudioModelArgs::get_codebook_sizes` when `include_special_tokens` is true.
const AUDIO_SPECIAL_TOKEN_COUNT: usize = 2;

#[derive(Debug, Clone, Deserialize)]
pub struct VoxtralConfig {
    pub dim: usize,
    pub n_layers: usize,
    pub head_dim: usize,
    pub hidden_dim: usize,
    pub n_heads: usize,
    pub n_kv_heads: usize,
    pub rope_theta: f64,
    pub norm_eps: f64,
    pub vocab_size: usize,
    pub max_seq_len: usize,
    pub multimodal: MultimodalConfig,
}

impl VoxtralConfig {
    pub fn from_bytes(bytes: impl AsRef<[u8]>) -> Result<Self, TtsError> {
        Ok(serde_json::from_slice(bytes.as_ref())?)
    }
}

/// The `multimodal` section of [`VoxtralConfig`].
#[derive(Debug, Clone, Deserialize)]
pub struct MultimodalConfig {
    /// Presumably the beginning-of-sequence token id (inferred from the name; TODO confirm).
    pub bos_token_id: u32,
    /// Audio model settings: codebook sizes, audio token ids, acoustic transformer.
    pub audio_model_args: MultimodalAudioModelArgs,
    /// Audio tokenizer settings. The key itself is required, but fields missing inside
    /// this object fall back to [`AudioTokenizerArgs::default`].
    pub audio_tokenizer_args: AudioTokenizerArgs,
}

/// The `audio_encoding_args` section of [`MultimodalAudioModelArgs`].
#[derive(Debug, Clone, Deserialize)]
pub struct AudioEncodingArgs {
    /// Audio sampling rate — presumably in Hz (TODO confirm).
    pub sampling_rate: u32,
    // `frame_rate` and `num_codebooks` are required JSON keys (no `#[serde(default)]`),
    // but nothing in this file reads them; they are kept private behind a leading `_`.
    #[serde(rename = "frame_rate")]
    _frame_rate: f64,
    #[serde(rename = "num_codebooks")]
    _num_codebooks: usize,
}

/// The `acoustic_transformer_args` section of [`MultimodalAudioModelArgs`].
///
/// The dimension/size fields are required; the remaining keys have serde defaults.
#[derive(Debug, Clone, Deserialize)]
pub struct AcousticTransformerArgs {
    pub input_dim: usize,
    pub dim: usize,
    pub n_layers: usize,
    pub head_dim: usize,
    pub hidden_dim: usize,
    pub n_heads: usize,
    pub n_kv_heads: usize,
    /// Defaults to `false` when the key is absent.
    #[serde(default)]
    pub use_biases: bool,
    /// Defaults to `default_norm_eps()` (`1e-5`) when the key is absent.
    #[serde(default = "default_norm_eps")]
    pub norm_eps: f64,
    // `sigma` (default `default_sigma()`) and the optional `sigma_max` are accepted
    // from the JSON but not read anywhere in this file.
    #[serde(default = "default_sigma", rename = "sigma")]
    _sigma: f64,
    #[serde(default, rename = "sigma_max")]
    _sigma_max: Option<f64>,
}

/// The `audio_model_args` section of [`MultimodalConfig`].
///
/// Underscore-prefixed private fields are optional JSON keys (all have
/// `#[serde(default)]`) that are accepted but never read in this file.
#[derive(Debug, Clone, Deserialize)]
pub struct MultimodalAudioModelArgs {
    /// Size of the semantic codebook; the first entry returned by `codebook_sizes`.
    pub semantic_codebook_size: usize,
    /// Size of each acoustic codebook.
    pub acoustic_codebook_size: usize,
    /// Number of acoustic codebooks that follow the semantic one in `codebook_sizes`.
    pub n_acoustic_codebook: usize,
    pub audio_encoding_args: AudioEncodingArgs,
    /// Presumably the id of the audio placeholder token (TODO confirm with the tokenizer).
    pub audio_token_id: u32,
    /// Presumably the id of the token that opens an audio span (TODO confirm).
    pub begin_audio_token_id: u32,
    #[serde(default, rename = "input_embedding_concat_type")]
    _input_embedding_concat_type: Option<String>,
    pub acoustic_transformer_args: AcousticTransformerArgs,
    #[serde(default, rename = "p_uncond")]
    _p_uncond: Option<f64>,
    #[serde(default, rename = "text_feature_bugged")]
    _text_feature_bugged: Option<bool>,
    #[serde(default, rename = "condition_dropped_token_id")]
    _condition_dropped_token_id: Option<u32>,
}

impl MultimodalAudioModelArgs {
    /// Returns the size of every codebook, semantic first.
    ///
    /// The result has `n_acoustic_codebook + 1` entries: `semantic_codebook_size`
    /// followed by `n_acoustic_codebook` copies of `acoustic_codebook_size`.
    pub fn codebook_sizes(&self) -> Vec<usize> {
        let acoustic = std::iter::repeat_n(self.acoustic_codebook_size, self.n_acoustic_codebook);
        std::iter::once(self.semantic_codebook_size)
            .chain(acoustic)
            .collect()
    }

    /// Returns [`Self::codebook_sizes`] adjusted for vocabulary layout.
    ///
    /// * `include_special_tokens` — when true, adds `AUDIO_SPECIAL_TOKEN_COUNT` to each size.
    /// * `pad_to_multiple` — when `Some(m)`, each (possibly enlarged) size is then rounded
    ///   with `round_up_to_multiple(size, m)`.
    pub fn get_codebook_sizes(
        &self,
        pad_to_multiple: Option<usize>,
        include_special_tokens: bool,
    ) -> Vec<usize> {
        let special = if include_special_tokens {
            AUDIO_SPECIAL_TOKEN_COUNT
        } else {
            0
        };

        let mut adjusted = self.codebook_sizes();
        for entry in adjusted.iter_mut() {
            *entry += special;
            if let Some(multiple) = pad_to_multiple {
                *entry = round_up_to_multiple(*entry, multiple);
            }
        }
        adjusted
    }
}

/// The `audio_tokenizer_args` section of [`MultimodalConfig`].
///
/// Because of the container-level `#[serde(default)]`, any key missing from the JSON
/// object is taken from [`AudioTokenizerArgs::default`].
#[derive(Debug, Clone, Deserialize)]
#[serde(default)]
pub struct AudioTokenizerArgs {
    pub channels: usize,
    /// Audio sampling rate; the numerator of [`AudioTokenizerArgs::frame_rate`].
    pub sampling_rate: u32,
    /// Multiplied with the encoder strides to form the denominator of `frame_rate`.
    pub pretransform_patch_size: usize,
    pub patch_proj_kernel_size: usize,
    pub semantic_codebook_size: usize,
    pub semantic_dim: usize,
    pub acoustic_codebook_size: usize,
    pub acoustic_dim: usize,
    pub conv_weight_norm: bool,
    pub causal: bool,
    pub attn_sliding_window_size: usize,
    pub half_attn_window_upon_downsampling: bool,
    pub dim: usize,
    pub hidden_dim: usize,
    pub head_dim: usize,
    pub n_heads: usize,
    pub n_kv_heads: usize,
    pub qk_norm_eps: f64,
    pub qk_norm: bool,
    pub use_biases: bool,
    pub norm_eps: f64,
    pub layer_scale: bool,
    pub layer_scale_init: Option<f64>,
    // The `*_str` fields hold comma-separated `usize` lists (e.g. "2,2,2,1"); several
    // are parsed on demand by the accessor methods via `parse_csv_usize`.
    pub encoder_transformer_lengths_str: String,
    pub encoder_convs_kernels_str: String,
    pub encoder_convs_strides_str: String,
    pub decoder_transformer_lengths_str: String,
    pub decoder_convs_kernels_str: String,
    pub decoder_convs_strides_str: String,
    /// Voice name -> index map; [`AudioTokenizerArgs::voice_names`] lists names by index.
    pub voice: BTreeMap<String, u32>,
}

impl Default for AudioTokenizerArgs {
    /// Values used for every key absent from the `audio_tokenizer_args` JSON object
    /// (the struct is deserialized with a container-level `#[serde(default)]`).
    fn default() -> Self {
        // By default the encoder and decoder use the same transformer-length list.
        let transformer_lengths = String::from("2,2,2,2");

        Self {
            channels: 1,
            sampling_rate: 24_000,
            pretransform_patch_size: 240,
            patch_proj_kernel_size: 7,
            semantic_codebook_size: 8192,
            semantic_dim: 256,
            acoustic_codebook_size: 21,
            acoustic_dim: 36,
            conv_weight_norm: true,
            causal: true,
            attn_sliding_window_size: 16,
            half_attn_window_upon_downsampling: true,
            dim: 1024,
            hidden_dim: 4096,
            head_dim: 128,
            n_heads: 8,
            n_kv_heads: 8,
            qk_norm_eps: 1e-6,
            qk_norm: true,
            use_biases: false,
            norm_eps: 1e-2,
            layer_scale: true,
            layer_scale_init: None,
            // Comma-separated `usize` lists.
            encoder_transformer_lengths_str: transformer_lengths.clone(),
            encoder_convs_kernels_str: String::from("4,4,4,3"),
            encoder_convs_strides_str: String::from("2,2,2,1"),
            decoder_transformer_lengths_str: transformer_lengths,
            decoder_convs_kernels_str: String::from("3,4,4,4"),
            decoder_convs_strides_str: String::from("1,2,2,2"),
            voice: BTreeMap::default(),
        }
    }
}

impl AudioTokenizerArgs {
    pub fn encoder_convs_strides(&self) -> Result<Vec<usize>, TtsError> {
        parse_csv_usize(&self.encoder_convs_strides_str)
    }

    pub fn decoder_transformer_lengths(&self) -> Result<Vec<usize>, TtsError> {
        parse_csv_usize(&self.decoder_transformer_lengths_str)
    }

    pub fn decoder_convs_kernels(&self) -> Result<Vec<usize>, TtsError> {
        parse_csv_usize(&self.decoder_convs_kernels_str)
    }

    pub fn decoder_convs_strides(&self) -> Result<Vec<usize>, TtsError> {
        parse_csv_usize(&self.decoder_convs_strides_str)
    }

    pub fn frame_rate(&self) -> Result<f64, TtsError> {
        let scale_factor: usize = self.encoder_convs_strides()?.into_iter().product();
        Ok(self.sampling_rate as f64 / (self.pretransform_patch_size * scale_factor) as f64)
    }

    pub fn voice_names(&self) -> Vec<String> {
        let mut entries: Vec<_> = self.voice.iter().collect();
        entries.sort_by_key(|(_, index)| **index);
        entries.into_iter().map(|(name, _)| name.clone()).collect()
    }
}

/// Serde fallback for `AcousticTransformerArgs::norm_eps` when the key is absent.
fn default_norm_eps() -> f64 {
    1.0e-5
}

/// Serde fallback for the `sigma` key of `AcousticTransformerArgs` when it is absent.
fn default_sigma() -> f64 {
    1.0e-5
}

/// Rounds `value` up to the nearest multiple of `multiple`.
///
/// A `multiple` of `0` means "no alignment" and returns `value` unchanged. Without this
/// guard, `usize::div_ceil` would panic on division by zero, for example when a caller
/// passes `pad_to_multiple: Some(0)` to `get_codebook_sizes`.
fn round_up_to_multiple(value: usize, multiple: usize) -> usize {
    if multiple == 0 {
        return value;
    }
    value.div_ceil(multiple) * multiple
}

/// Parses a comma-separated list of `usize` values such as `"2,2,2,1"`.
///
/// Whitespace around each entry is ignored, so `"2, 2, 2, 1"` is also accepted.
///
/// # Errors
/// Returns [`TtsError::ConfigError`] (quoting the whole list) if any entry is empty or not
/// a valid `usize`. This includes the empty string and trailing commas.
fn parse_csv_usize(value: &str) -> Result<Vec<usize>, TtsError> {
    value
        .split(',')
        .map(|part| {
            part.trim().parse::<usize>().map_err(|err| {
                TtsError::ConfigError(format!("Invalid Voxtral config list '{value}': {err}"))
            })
        })
        .collect()
}