use candle_core::{DType, Device, Result as CandleResult, Tensor};
use candle_nn::{Embedding, Linear, Module, VarBuilder};
use crate::error::TtsError;
use crate::layers::attention::GqaConfig;
use crate::layers::transformer::TransformerBlock;
use crate::tensor_utils::{precompute_rope_freqs, RmsNorm};
use super::config::VibeVoiceDecoderConfig;
use super::generation::{DecoderCacheState, LayerKvCache};
pub(crate) struct SpeechConnector {
fc1: Linear,
norm: RmsNorm,
fc2: Linear,
}
impl SpeechConnector {
pub(crate) fn load(input_dim: usize, output_dim: usize, vb: VarBuilder) -> CandleResult<Self> {
let fc1 = candle_nn::linear(input_dim, output_dim, vb.pp("fc1"))?;
let norm = RmsNorm::load(output_dim, 1e-6, vb.pp("norm"))?;
let fc2 = candle_nn::linear(output_dim, output_dim, vb.pp("fc2"))?;
Ok(Self { fc1, norm, fc2 })
}
pub(crate) fn forward(&self, features: &Tensor) -> CandleResult<Tensor> {
let original_dims = features.dims().to_vec();
let input_dim = *original_dims.last().unwrap_or(&0);
let features_2d = connector_input(features, &original_dims, input_dim)?;
let hidden = self.fc1.forward(&features_2d)?;
let hidden = self.norm.forward(&hidden)?;
let hidden = self.fc2.forward(&hidden)?;
reshape_connector_output(hidden, &original_dims)
}
}
pub(crate) struct VibeVoiceLanguageModel {
embed_tokens: Embedding,
layers: Vec<TransformerBlock>,
norm: RmsNorm,
rope_cos: Tensor,
rope_sin: Tensor,
dtype: DType,
}
impl VibeVoiceLanguageModel {
pub(crate) fn load(
config: &VibeVoiceDecoderConfig,
vb: VarBuilder,
device: &Device,
dtype: DType,
) -> Result<Self, TtsError> {
let head_dim = config.hidden_size / config.num_attention_heads;
let gqa_config = GqaConfig::with_head_dim(
config.hidden_size,
config.num_attention_heads,
config.num_key_value_heads,
head_dim,
config.max_position_embeddings,
config.rope_theta,
config.rms_norm_eps,
)
.with_attention_bias(config.attention_bias);
let embed_tokens =
candle_nn::embedding(config.vocab_size, config.hidden_size, vb.pp("embed_tokens"))?;
let mut layers = Vec::with_capacity(config.num_hidden_layers);
for index in 0..config.num_hidden_layers {
layers.push(TransformerBlock::load(
&gqa_config,
config.intermediate_size,
vb.pp(format!("layers.{}", index)),
)?);
}
let norm = RmsNorm::load(config.hidden_size, config.rms_norm_eps, vb.pp("norm"))?;
let (rope_cos, rope_sin) = precompute_rope_freqs(
head_dim,
config.max_position_embeddings,
config.rope_theta,
device,
dtype,
)?;
Ok(Self {
embed_tokens,
layers,
norm,
rope_cos,
rope_sin,
dtype,
})
}
pub(crate) fn embed(&self, token_ids: &Tensor) -> CandleResult<Tensor> {
self.embed_tokens.forward(token_ids)
}
pub(crate) fn prefill(&mut self, input_embeds: &Tensor) -> Result<DecoderCacheState, TtsError> {
let (_batch, seq_len, _hidden) = input_embeds.dims3()?;
self.clear_cache();
let mask = causal_mask(seq_len, input_embeds.device(), self.dtype)?;
let hidden = self.forward_hidden(input_embeds, 0, mask.as_ref())?;
self.capture_decode_state(hidden, seq_len)
}
pub(crate) fn decode_step(
&mut self,
input_embedding: &Tensor,
start_pos: usize,
) -> Result<DecoderCacheState, TtsError> {
let input_embeds = step_input_embeddings(input_embedding)?;
let hidden = self.forward_hidden(&input_embeds, start_pos, None)?;
self.capture_decode_state(hidden, start_pos + 1)
}
pub(crate) fn load_cache_state(
&mut self,
cache_state: &[LayerKvCache],
) -> Result<(), TtsError> {
if cache_state.len() != self.layers.len() {
return Err(TtsError::ModelError(format!(
"VibeVoice cache state has {} layer(s), expected {}",
cache_state.len(),
self.layers.len(),
)));
}
for (layer, cache_entry) in self.layers.iter_mut().zip(cache_state.iter()) {
layer.set_cache_state(cache_entry.clone());
}
Ok(())
}
pub(crate) fn clear_cache(&mut self) {
for layer in &mut self.layers {
layer.clear_cache();
}
}
fn capture_decode_state(
&self,
hidden: Tensor,
next_position: usize,
) -> Result<DecoderCacheState, TtsError> {
let last_hidden = hidden.narrow(1, hidden.dim(1)? - 1, 1)?.squeeze(1)?;
let logits = self.next_logits(&last_hidden)?;
Ok(DecoderCacheState::new(
next_position,
last_hidden,
logits,
self.capture_cache_state(),
))
}
fn capture_cache_state(&self) -> Vec<LayerKvCache> {
self.layers
.iter()
.map(TransformerBlock::cache_state)
.collect()
}
fn forward_hidden(
&mut self,
input_embeds: &Tensor,
start_pos: usize,
mask: Option<&Tensor>,
) -> CandleResult<Tensor> {
let mut hidden = input_embeds.clone();
for layer in &mut self.layers {
hidden = layer.forward(&hidden, &self.rope_cos, &self.rope_sin, start_pos, mask)?;
}
self.norm.forward(&hidden)
}
fn next_logits(&self, last_hidden: &Tensor) -> CandleResult<Tensor> {
let weight = self.embed_tokens.embeddings().transpose(0, 1)?;
last_hidden.matmul(&weight)
}
}
fn flattened_leading_dims(dims: &[usize]) -> usize {
if dims.len() <= 1 {
return 1;
}
dims[..dims.len() - 1].iter().product::<usize>()
}
fn connector_input(
features: &Tensor,
original_dims: &[usize],
input_dim: usize,
) -> CandleResult<Tensor> {
if original_dims.len() == 2 {
return Ok(features.clone());
}
let leading = flattened_leading_dims(original_dims);
features.reshape((leading, input_dim))
}
fn connector_output_dims(original_dims: &[usize], hidden_dim: usize) -> Option<Vec<usize>> {
if original_dims.len() == 2 {
return None;
}
let mut output_dims = original_dims.to_vec();
if let Some(last) = output_dims.last_mut() {
*last = hidden_dim;
}
Some(output_dims)
}
fn reshape_connector_output(hidden: Tensor, original_dims: &[usize]) -> CandleResult<Tensor> {
let Some(output_dims) =
connector_output_dims(original_dims, hidden.dim(candle_core::D::Minus1)?)
else {
return Ok(hidden);
};
hidden.reshape(output_dims)
}
fn causal_mask(seq_len: usize, device: &Device, dtype: DType) -> CandleResult<Option<Tensor>> {
if seq_len <= 1 {
return Ok(None);
}
let mut mask_data = vec![f32::NEG_INFINITY; seq_len * seq_len];
for row in 0..seq_len {
for col in 0..=row {
mask_data[row * seq_len + col] = 0.0;
}
}
let mask = Tensor::from_vec(mask_data, (seq_len, seq_len), device)?
.to_dtype(dtype)?
.unsqueeze(0)?
.unsqueeze(0)?;
Ok(Some(mask))
}
fn step_input_embeddings(input_embedding: &Tensor) -> Result<Tensor, TtsError> {
match input_embedding.rank() {
1 => input_embedding
.unsqueeze(0)?
.unsqueeze(0)
.map_err(Into::into),
2 => input_embedding.unsqueeze(0).map_err(Into::into),
3 => Ok(input_embedding.clone()),
_ => Err(TtsError::ModelError(
"Unexpected VibeVoice embedding rank while decoding incrementally".to_string(),
)),
}
}