any-tts 0.1.1

A Rust TTS library with Candle backends and runtime adapters for modern open TTS models
Documentation
use candle_core::{DType, Device, Result as CandleResult, Tensor};
use candle_nn::{Embedding, Linear, Module, VarBuilder};

use crate::error::TtsError;
use crate::layers::attention::GqaConfig;
use crate::layers::transformer::TransformerBlock;
use crate::models::vibevoice::generation::LayerKvCache;
use crate::tensor_utils::{precompute_rope_freqs, RmsNorm};

use super::config::VibeVoiceRealtimeConfig;

/// Snapshot of the realtime decoder produced by one incremental decode step.
///
/// It bundles three things: the position the following step should presumably start from, the
/// hidden state of the last sequence element, and a copy of every layer's key/value cache. The
/// caches can be handed back to the model to resume decoding.
#[derive(Clone)]
pub struct RealtimeDecoderState {
    /// Position index following the tokens already consumed by the decoder.
    next_position: usize,
    /// Hidden state of the final sequence element from the step that built this snapshot.
    last_hidden: Tensor,
    /// Per-layer key/value caches, one entry per transformer layer.
    layer_caches: Vec<LayerKvCache>,
}

impl RealtimeDecoderState {
    /// Assembles a snapshot from its position, last hidden state and per-layer caches.
    pub fn new(next_position: usize, last_hidden: Tensor, layer_caches: Vec<LayerKvCache>) -> Self {
        let state = RealtimeDecoderState {
            layer_caches,
            last_hidden,
            next_position,
        };
        state
    }

    /// Position index the next decode step should use.
    pub fn next_position(&self) -> usize {
        let position = self.next_position;
        position
    }

    /// Borrowed view of the last hidden state.
    pub fn last_hidden(&self) -> &Tensor {
        let hidden: &Tensor = &self.last_hidden;
        hidden
    }

    /// Borrowed view of the per-layer key/value caches.
    pub fn layer_caches(&self) -> &[LayerKvCache] {
        self.layer_caches.as_slice()
    }
}

/// Decoder-only language model used by VibeVoice Realtime for incremental decoding.
///
/// Holds the token embedding table, the first `layer_count` transformer blocks, an optional
/// final RMS norm, and rotary-embedding tables shared by every layer.
pub struct RealtimeLanguageModel {
    embed_tokens: Embedding,
    layers: Vec<TransformerBlock>,
    /// `None` when the model was loaded with `apply_final_norm == false`.
    norm: Option<RmsNorm>,
    rope_cos: Tensor,
    rope_sin: Tensor,
}

impl RealtimeLanguageModel {
    /// Loads the embedding table, `layer_count` transformer blocks (`layers.0`, `layers.1`, ...)
    /// and, when `apply_final_norm` is set, the final norm from `vb`. It also precomputes the
    /// rotary tables on `device` in `dtype`.
    ///
    /// # Errors
    ///
    /// Returns [`TtsError::ModelError`] when the decoder config has zero attention heads or a
    /// hidden size that is not a multiple of the head count. Weight-loading and tensor errors
    /// are propagated.
    pub fn load(
        config: &VibeVoiceRealtimeConfig,
        vb: VarBuilder,
        device: &Device,
        dtype: DType,
        layer_count: usize,
        apply_final_norm: bool,
    ) -> Result<Self, TtsError> {
        let decoder = &config.decoder_config;
        // Guard the division below. Zero heads would panic, and a non-exact division would
        // silently truncate the per-head dimension.
        if decoder.num_attention_heads == 0
            || decoder.hidden_size % decoder.num_attention_heads != 0
        {
            return Err(TtsError::ModelError(format!(
                "VibeVoice Realtime hidden_size {} is not divisible by num_attention_heads {}",
                decoder.hidden_size, decoder.num_attention_heads,
            )));
        }
        let head_dim = decoder.hidden_size / decoder.num_attention_heads;
        let gqa_config = GqaConfig::with_head_dim(
            decoder.hidden_size,
            decoder.num_attention_heads,
            decoder.num_key_value_heads,
            head_dim,
            decoder.max_position_embeddings,
            decoder.rope_theta,
            decoder.rms_norm_eps,
        )
        .with_attention_bias(decoder.attention_bias);

        let embed_tokens = candle_nn::embedding(
            decoder.vocab_size,
            decoder.hidden_size,
            vb.pp("embed_tokens"),
        )?;

        let mut layers = Vec::with_capacity(layer_count);
        for index in 0..layer_count {
            layers.push(TransformerBlock::load(
                &gqa_config,
                decoder.intermediate_size,
                vb.pp(format!("layers.{index}")),
            )?);
        }

        let norm = if apply_final_norm {
            Some(RmsNorm::load(
                decoder.hidden_size,
                decoder.rms_norm_eps,
                vb.pp("norm"),
            )?)
        } else {
            None
        };
        let (rope_cos, rope_sin) = precompute_rope_freqs(
            head_dim,
            decoder.max_position_embeddings,
            decoder.rope_theta,
            device,
            dtype,
        )?;

        Ok(Self {
            embed_tokens,
            layers,
            norm,
            rope_cos,
            rope_sin,
        })
    }

    /// Looks up the token embeddings for `token_ids`.
    pub fn embed(&self, token_ids: &Tensor) -> CandleResult<Tensor> {
        self.embed_tokens.forward(token_ids)
    }

    /// Restores every layer's key/value cache from `cache_state`, matching entries to layers
    /// by position. Each entry is cloned into its layer.
    ///
    /// # Errors
    ///
    /// Returns [`TtsError::ModelError`] when the number of entries differs from the number
    /// of loaded layers. No layer is modified in that case.
    pub fn load_cache_state(&mut self, cache_state: &[LayerKvCache]) -> Result<(), TtsError> {
        if cache_state.len() != self.layers.len() {
            return Err(TtsError::ModelError(format!(
                "VibeVoice Realtime cache state has {} layer(s), expected {}",
                cache_state.len(),
                self.layers.len(),
            )));
        }

        for (layer, cache_entry) in self.layers.iter_mut().zip(cache_state.iter()) {
            layer.set_cache_state(cache_entry.clone());
        }
        Ok(())
    }

    /// Runs the embeddings through every layer, starting at rotary position `start_pos`.
    ///
    /// Inputs of rank 1 and 2 are lifted to rank 3 by prepending singleton axes. This
    /// presumably means `(hidden)` and `(seq, hidden)` respectively, giving
    /// `(1, seq, hidden)`; TODO confirm against callers. The returned state holds three
    /// things:
    /// - the last sequence element's hidden state (after the optional final norm);
    /// - the next position, `start_pos` advanced by the number of consumed sequence elements;
    /// - a snapshot of every layer's cache.
    ///
    /// # Errors
    ///
    /// Returns [`TtsError::ModelError`] for inputs of rank 0 or above 3, and for an empty
    /// sequence axis. Tensor errors from the layers are propagated.
    pub fn decode_step(
        &mut self,
        input_embedding: &Tensor,
        start_pos: usize,
    ) -> Result<RealtimeDecoderState, TtsError> {
        let input_embeds = match input_embedding.rank() {
            1 => input_embedding.unsqueeze(0)?.unsqueeze(0)?,
            2 => input_embedding.unsqueeze(0)?,
            3 => input_embedding.clone(),
            _ => {
                return Err(TtsError::ModelError(
                    "Unexpected VibeVoice Realtime embedding rank while decoding incrementally"
                        .to_string(),
                ));
            }
        };

        // Number of positions consumed by this step. An empty sequence has no "last" element,
        // and `seq_len - 1` below would underflow.
        let seq_len = input_embeds.dim(1)?;
        if seq_len == 0 {
            return Err(TtsError::ModelError(
                "VibeVoice Realtime decode step received an empty embedding sequence".to_string(),
            ));
        }

        let mut hidden = input_embeds;
        for layer in &mut self.layers {
            hidden = layer.forward(&hidden, &self.rope_cos, &self.rope_sin, start_pos, None)?;
        }
        if let Some(norm) = &self.norm {
            hidden = norm.forward(&hidden)?;
        }
        // Transformer blocks keep the sequence length, so the last element sits at seq_len - 1.
        let last_hidden = hidden.narrow(1, seq_len - 1, 1)?.squeeze(1)?;

        Ok(RealtimeDecoderState::new(
            // Advance by every consumed position, not just one, so multi-token chunks keep the
            // rotary position aligned with the cache length on the next call.
            start_pos + seq_len,
            last_hidden,
            self.layers
                .iter()
                .map(TransformerBlock::cache_state)
                .collect(),
        ))
    }
}

/// Two-layer MLP head mapping a hidden state to a single raw score.
///
/// The layout is `fc1` (hidden -> hidden), then ReLU, then `fc2` (hidden -> 1). No output
/// activation is applied.
pub struct BinaryClassifier {
    fc1: Linear,
    fc2: Linear,
}

impl BinaryClassifier {
    /// Loads both linear layers (with bias) from the `fc1` and `fc2` sub-paths of `vb`.
    pub fn load(hidden_size: usize, vb: VarBuilder) -> CandleResult<Self> {
        Ok(Self {
            fc1: candle_nn::linear(hidden_size, hidden_size, vb.pp("fc1"))?,
            fc2: candle_nn::linear(hidden_size, 1, vb.pp("fc2"))?,
        })
    }

    /// Applies `fc1`, ReLU and `fc2` in sequence. The trailing dimension of the result is 1.
    pub fn forward(&self, hidden: &Tensor) -> CandleResult<Tensor> {
        hidden.apply(&self.fc1)?.relu()?.apply(&self.fc2)
    }
}