any-tts 0.1.0

A Rust TTS library with Candle backends and runtime adapters for modern open TTS models
Documentation
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use tracing::info;

use crate::config::ModelFiles;
use crate::error::TtsError;
use crate::tokenizer::TextTokenizer;

use super::config::{VibeVoiceConfig, VibeVoicePreprocessorConfig};
use super::diffusion::VibeVoiceDiffusionHead;
use super::processor::{VibeVoiceProcessor, VibeVoiceTokenizerSpec};
use super::runtime::{SpeechConnector, VibeVoiceLanguageModel};
use super::speech_tokenizer::{VibeVoiceAcousticTokenizer, VibeVoiceSemanticTokenizer};

/// Bundle of VibeVoice model parts produced by `load_components`.
///
/// Every member is read from the same safetensors weights, beneath the
/// top-level `model` prefix.
pub(super) struct LoadedVibeVoiceComponents {
    /// Language model loaded from the `model.language_model` weights.
    pub language_model: VibeVoiceLanguageModel,
    /// Acoustic tokenizer loaded from the `model.acoustic_tokenizer` weights.
    pub acoustic_tokenizer: VibeVoiceAcousticTokenizer,
    /// Semantic tokenizer loaded from the `model.semantic_tokenizer` weights.
    pub semantic_tokenizer: VibeVoiceSemanticTokenizer,
    /// Connector loaded from `model.acoustic_connector`, built with the
    /// configured acoustic VAE dimension and the decoder hidden size.
    pub acoustic_connector: SpeechConnector,
    /// Connector loaded from `model.semantic_connector`, built with the
    /// configured semantic VAE dimension and the decoder hidden size.
    pub semantic_connector: SpeechConnector,
    /// Diffusion head loaded from the `model.prediction_head` weights.
    pub prediction_head: VibeVoiceDiffusionHead,
    /// Value of the `model.speech_scaling_factor` tensor, converted to `f32`.
    // NOTE(review): presumably applied to speech latents together with the
    // bias factor; the exact formula lives outside this file, so confirm there.
    pub speech_scaling_factor: f32,
    /// Value of the `model.speech_bias_factor` tensor, converted to `f32`.
    pub speech_bias_factor: f32,
}

/// Chooses the dtype VibeVoice will actually run with on `device`.
///
/// * CPU with a `BF16` request is downgraded to `F32`.
/// * Metal always runs in `F32`, whatever was requested.
/// * Every other combination keeps `requested` unchanged.
///
/// Each override is reported through an `info!` log line.
pub(super) fn resolve_runtime_dtype(device: &Device, requested: DType) -> DType {
    match device {
        Device::Cpu if requested == DType::BF16 => {
            info!("BF16 is not supported on CPU; falling back to F32 for VibeVoice");
            DType::F32
        }
        Device::Metal(_) => {
            info!(
                "VibeVoice uses F32 on Metal to match the Python reference path and avoid unstable first-token behavior"
            );
            DType::F32
        }
        // Any other device (or a CPU request that is not BF16) is honored as-is.
        _ => requested,
    }
}

/// Reads the preprocessor configuration for a VibeVoice model.
///
/// When `files` carries a preprocessor config asset, its bytes are parsed with
/// `VibeVoicePreprocessorConfig::from_bytes`; when it does not, the default
/// configuration is returned.
///
/// # Errors
///
/// Propagates any failure from reading the asset or from parsing its bytes.
pub(super) fn load_preprocessor_config(
    files: &ModelFiles,
) -> Result<VibeVoicePreprocessorConfig, TtsError> {
    match &files.preprocessor_config {
        Some(asset) => {
            let raw = asset.read_bytes()?;
            VibeVoicePreprocessorConfig::from_bytes(raw.as_ref())
        }
        // No asset shipped with the model: fall back to the built-in defaults.
        None => Ok(VibeVoicePreprocessorConfig::default()),
    }
}

/// Assembles a `VibeVoiceProcessor` from the model's tokenizer asset and a
/// copy of `preprocessor_config`.
///
/// # Errors
///
/// Propagates failures from loading the tokenizer or from deriving the
/// `VibeVoiceTokenizerSpec` from it.
///
/// # Panics
///
/// Panics if `files.tokenizer` is `None`. The `expect` message records that
/// `resolve_files` is relied on to have checked this beforehand.
pub(super) fn build_processor(
    files: &ModelFiles,
    preprocessor_config: &VibeVoicePreprocessorConfig,
) -> Result<VibeVoiceProcessor, TtsError> {
    let tokenizer_asset = files
        .tokenizer
        .as_ref()
        .expect("validated by resolve_files");
    let tokenizer = TextTokenizer::from_asset(tokenizer_asset)?;
    // The spec is derived while the tokenizer is still borrowed; the tokenizer
    // itself is then moved into the processor.
    let spec = VibeVoiceTokenizerSpec::from_tokenizer(&tokenizer)?;
    let processor = VibeVoiceProcessor::new(tokenizer, spec, preprocessor_config.clone());
    Ok(processor)
}

/// Loads every VibeVoice sub-model from the safetensors weights in `files`.
///
/// A `VarBuilder` is opened over `files.weights` with the given `dtype` and
/// `device`, and all components are read from beneath its `model` prefix:
/// the language model, both speech tokenizers, both speech connectors, the
/// diffusion prediction head, and the two scalar speech factors.
///
/// # Errors
///
/// Returns the first error raised while opening the weights or while loading
/// any individual component, in the order listed above.
pub(super) fn load_components(
    files: &ModelFiles,
    model_config: &VibeVoiceConfig,
    device: &Device,
    dtype: DType,
) -> Result<LoadedVibeVoiceComponents, TtsError> {
    let weights = ModelFiles::load_safetensors_vb(&files.weights, dtype, device)?;
    let root = weights.pp("model");
    let decoder_config = &model_config.decoder_config;
    // Both connectors project into the decoder's hidden dimension.
    let hidden_size = decoder_config.hidden_size;

    let language_model =
        VibeVoiceLanguageModel::load(decoder_config, root.pp("language_model"), device, dtype)?;

    let acoustic_tokenizer = VibeVoiceAcousticTokenizer::load(
        &model_config.acoustic_tokenizer_config,
        root.pp("acoustic_tokenizer"),
    )?;
    let semantic_tokenizer = VibeVoiceSemanticTokenizer::load(
        &model_config.semantic_tokenizer_config,
        root.pp("semantic_tokenizer"),
    )?;

    let acoustic_connector = SpeechConnector::load(
        model_config.acoustic_vae_dim,
        hidden_size,
        root.pp("acoustic_connector"),
    )?;
    let semantic_connector = SpeechConnector::load(
        model_config.semantic_vae_dim,
        hidden_size,
        root.pp("semantic_connector"),
    )?;

    let prediction_head = VibeVoiceDiffusionHead::load(
        &model_config.diffusion_head_config,
        root.pp("prediction_head"),
    )?;

    // The scalar factors are read last, scaling before bias.
    let speech_scaling_factor = load_scalar(&root, "speech_scaling_factor")?;
    let speech_bias_factor = load_scalar(&root, "speech_bias_factor")?;

    Ok(LoadedVibeVoiceComponents {
        language_model,
        acoustic_tokenizer,
        semantic_tokenizer,
        acoustic_connector,
        semantic_connector,
        prediction_head,
        speech_scaling_factor,
        speech_bias_factor,
    })
}

/// Reads the tensor called `name` from `model_vb` and returns its value as `f32`.
///
/// The tensor is first requested with shape `1`; if that lookup fails for any
/// reason, it is requested again with the empty (rank-0) shape `()`.
///
/// # Errors
///
/// Returns the error from the second lookup when both fail, or any error
/// produced while converting the tensor to a scalar.
fn load_scalar(model_vb: &VarBuilder, name: &str) -> Result<f32, TtsError> {
    let tensor = match model_vb.get(1, name) {
        Ok(found) => found,
        // The error from the first attempt is discarded; only the fallback's
        // error is reported to the caller.
        Err(_) => model_vb.get((), name)?,
    };
    scalar_from_tensor(&tensor)
}

/// Extracts the first element of `tensor` as an `f32`.
///
/// The tensor is converted to `F32`, flattened, and copied into a `Vec<f32>`;
/// the first entry of that vector is returned. Any further elements are
/// ignored.
///
/// # Errors
///
/// Propagates tensor conversion failures, and returns
/// `TtsError::ModelError` when the tensor holds no elements.
fn scalar_from_tensor(tensor: &Tensor) -> Result<f32, TtsError> {
    let as_f32 = tensor.to_dtype(DType::F32)?;
    let flat = as_f32.flatten_all()?;
    let values: Vec<f32> = flat.to_vec1()?;
    match values.first() {
        Some(&value) => Ok(value),
        None => Err(TtsError::ModelError("Expected scalar tensor".to_string())),
    }
}