use candle_core::{IndexOp, Result, Tensor};
use candle_nn::{Embedding, Linear, Module, VarBuilder, embedding, linear_no_bias};
use crate::{
config::talker_config::TalkerConfig,
nn::{
code_predictor::TalkerCodePredictorForConditionalGeneration,
decoder_layer::TalkerDecoderLayer, kv_cache::KVCache, mlp::TalkerResizeMLP, norm::RMSNorm,
rope::talker::TalkerRotaryEmbedding,
},
};
/// Talker backbone: a stack of decoder layers followed by a final RMS norm, plus the
/// codec and text embedding tables that feed it.
#[derive(Debug, Clone)]
pub struct TalkerModel {
// Decoder layers, applied in order by `forward` / `forward_with_cache`.
layers: Vec<TalkerDecoderLayer>,
// Normalisation applied to the output of the last decoder layer.
norm: RMSNorm,
// Produces the `(cos, sin)` pair handed to every decoder layer.
rotary_emb: TalkerRotaryEmbedding,
// Embedding table built with `vocab_size` rows of width `hidden_size`.
codec_embedding: Embedding,
// Embedding table built with `text_vocab_size` rows of width `text_hidden_size`.
text_embedding: Embedding,
}
impl TalkerModel {
pub fn new(config: &TalkerConfig, use_flash_attn: bool, vb: VarBuilder) -> Result<Self> {
let layers = (0..config.num_hidden_layers)
.map(|i| {
TalkerDecoderLayer::new(config, i, use_flash_attn, vb.pp(format!("layers.{}", i)))
})
.collect::<Result<Vec<_>>>()?;
let norm = RMSNorm::new(config.hidden_size, config.rms_norm_eps, vb.pp("norm"))?;
let head_dim = config.head_dim();
let rotary_emb = TalkerRotaryEmbedding::new(
head_dim,
config.max_position_embeddings,
config.rope_theta,
vb.device(),
)?;
let codec_embedding = embedding(
config.vocab_size,
config.hidden_size,
vb.pp("codec_embedding"),
)?;
let text_embedding = embedding(
config.text_vocab_size,
config.text_hidden_size,
vb.pp("text_embedding"),
)?;
Ok(Self {
layers,
norm,
rotary_emb,
codec_embedding,
text_embedding,
})
}
pub fn load(config: &TalkerConfig, use_flash_attn: bool, vb: VarBuilder) -> Result<Self> {
Self::new(config, use_flash_attn, vb)
}
pub fn get_codec_embedding(&self) -> &Embedding {
&self.codec_embedding
}
pub fn get_text_embedding(&self) -> &Embedding {
&self.text_embedding
}
pub fn forward(
&self,
inputs_embeds: &Tensor,
position_ids: &Tensor,
attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
let (cos, sin) = self.rotary_emb.forward(inputs_embeds, position_ids)?;
let mut hidden_states = inputs_embeds.clone();
for (layer_idx, layer) in self.layers.iter().enumerate() {
hidden_states = layer.forward(&hidden_states, (&cos, &sin), attention_mask)?;
if tracing::enabled!(tracing::Level::TRACE) {
Self::print_layer_stats(&hidden_states, layer_idx)?;
}
}
let output = self.norm.forward(&hidden_states)?;
if tracing::enabled!(tracing::Level::TRACE) {
Self::print_after_norm_stats(&output)?;
}
Ok(output)
}
pub fn forward_with_cache(
&self,
inputs_embeds: &Tensor,
position_ids: &Tensor,
attention_mask: Option<&Tensor>,
cache: &mut KVCache,
) -> Result<Tensor> {
let (cos, sin) = self.rotary_emb.forward(inputs_embeds, position_ids)?;
let mut hidden_states = inputs_embeds.clone();
for (layer_idx, layer) in self.layers.iter().enumerate() {
hidden_states =
layer.forward_with_cache(&hidden_states, (&cos, &sin), attention_mask, cache)?;
if tracing::enabled!(tracing::Level::TRACE) {
Self::print_layer_stats(&hidden_states, layer_idx)?;
}
}
let output = self.norm.forward(&hidden_states)?;
if tracing::enabled!(tracing::Level::TRACE) {
Self::print_after_norm_stats(&output)?;
}
Ok(output)
}
fn print_layer_stats(hidden_states: &Tensor, layer_idx: usize) -> Result<()> {
let hs_f32 = hidden_states.to_dtype(candle_core::DType::F32)?;
let mean_val = hs_f32.mean_all()?.to_scalar::<f32>()?;
let min_val = hs_f32
.min(candle_core::D::Minus1)?
.min(candle_core::D::Minus1)?
.to_vec1::<f32>()?[0];
let max_val = hs_f32
.max(candle_core::D::Minus1)?
.max(candle_core::D::Minus1)?
.to_vec1::<f32>()?[0];
let seq_len = hidden_states.dim(1)?;
let last_pos = hs_f32.i((.., seq_len - 1, ..5))?.to_vec2::<f32>()?;
let first_vals: Vec<String> = last_pos[0].iter().map(|v| format!("{:.4}", v)).collect();
let variance = hs_f32.var(candle_core::D::Minus1)?;
let variance_mean = variance.mean_all()?.to_scalar::<f32>()?;
let std_val = variance_mean.sqrt();
tracing::trace!(
layer = layer_idx,
mean = format!("{:.6}", mean_val),
std = format!("{:.4}", std_val),
min = format!("{:.4}", min_val),
max = format!("{:.4}", max_val),
first_values = %first_vals.join(", "),
"Layer stats"
);
Ok(())
}
fn print_after_norm_stats(hidden_states: &Tensor) -> Result<()> {
let hs_f32 = hidden_states.to_dtype(candle_core::DType::F32)?;
let mean_val = hs_f32.mean_all()?.to_scalar::<f32>()?;
let min_val = hs_f32
.min(candle_core::D::Minus1)?
.min(candle_core::D::Minus1)?
.to_vec1::<f32>()?[0];
let max_val = hs_f32
.max(candle_core::D::Minus1)?
.max(candle_core::D::Minus1)?
.to_vec1::<f32>()?[0];
let seq_len = hidden_states.dim(1)?;
let last_pos = hs_f32.i((.., seq_len - 1, ..5))?.to_vec2::<f32>()?;
let first_vals: Vec<String> = last_pos[0].iter().map(|v| format!("{:.4}", v)).collect();
let variance = hs_f32.var(candle_core::D::Minus1)?;
let variance_mean = variance.mean_all()?.to_scalar::<f32>()?;
let std_val = variance_mean.sqrt();
tracing::trace!(
mean = format!("{:.6}", mean_val),
std = format!("{:.4}", std_val),
min = format!("{:.4}", min_val),
max = format!("{:.4}", max_val),
first_values = %first_vals.join(", "),
"After norm stats"
);
Ok(())
}
pub fn num_layers(&self) -> usize {
self.layers.len()
}
}
/// Talker backbone wrapped with its output head, text projection and code predictor,
/// together with the special codec token ids copied from `TalkerConfig`.
#[derive(Debug, Clone)]
pub struct TalkerForConditionalGeneration {
// Backbone producing hidden states from input embeddings.
model: TalkerModel,
// Bias-free linear head (`hidden_size` -> `vocab_size`), loaded from the `codec_head` prefix.
lm_head: Linear,
// MLP applied to text embeddings; built with `text_hidden_size` in and `hidden_size` out.
text_projection: TalkerResizeMLP,
// Generates the remaining codes in `generate_step` and supplies their input embeddings.
code_predictor: TalkerCodePredictorForConditionalGeneration,
// Special codec token ids, copied verbatim from the config at construction time.
codec_eos_token_id: usize,
codec_bos_id: usize,
codec_pad_id: usize,
codec_think_id: usize,
codec_nothink_id: usize,
codec_think_bos_id: usize,
codec_think_eos_id: usize,
// Upper bound on how many codebooks `sum_code_embeddings` will add up.
num_code_groups: usize,
// Copy of `config.hidden_size`, exposed through `hidden_size()`.
hidden_size: usize,
}
impl TalkerForConditionalGeneration {
pub fn new(config: &TalkerConfig, use_flash_attn: bool, vb: VarBuilder) -> Result<Self> {
let model = TalkerModel::new(config, use_flash_attn, vb.pp("model"))?;
let lm_head = linear_no_bias(
config.hidden_size,
config.vocab_size,
vb.pp("codec_head"), )?;
let text_projection = TalkerResizeMLP::new(
config.text_hidden_size,
config.text_hidden_size,
config.hidden_size,
&config.hidden_act,
true, vb.pp("text_projection"),
)?;
let code_predictor = TalkerCodePredictorForConditionalGeneration::new(
&config.code_predictor_config,
config,
use_flash_attn,
vb.pp("code_predictor"),
)?;
Ok(Self {
model,
lm_head,
text_projection,
code_predictor,
codec_eos_token_id: config.codec_eos_token_id,
codec_bos_id: config.codec_bos_id,
codec_pad_id: config.codec_pad_id,
codec_think_id: config.codec_think_id,
codec_nothink_id: config.codec_nothink_id,
codec_think_bos_id: config.codec_think_bos_id,
codec_think_eos_id: config.codec_think_eos_id,
num_code_groups: config.num_code_groups,
hidden_size: config.hidden_size,
})
}
pub fn load(config: &TalkerConfig, use_flash_attn: bool, vb: VarBuilder) -> Result<Self> {
Self::new(config, use_flash_attn, vb)
}
pub fn get_model(&self) -> &TalkerModel {
&self.model
}
pub fn get_code_predictor(&self) -> &TalkerCodePredictorForConditionalGeneration {
&self.code_predictor
}
pub fn get_special_tokens(&self) -> (usize, usize, usize) {
(
self.codec_eos_token_id,
self.codec_bos_id,
self.codec_pad_id,
)
}
pub fn get_think_tokens(&self) -> (usize, usize, usize, usize) {
(
self.codec_think_id,
self.codec_nothink_id,
self.codec_think_bos_id,
self.codec_think_eos_id,
)
}
pub fn hidden_size(&self) -> usize {
self.hidden_size
}
pub fn project_text_embeds(&self, text_embeds: &Tensor) -> Result<Tensor> {
self.text_projection.forward(text_embeds)
}
pub fn embed_and_project_text(&self, text_ids: &Tensor) -> Result<Tensor> {
let text_embeds = self.model.text_embedding.forward(text_ids)?;
self.text_projection.forward(&text_embeds)
}
pub fn get_codec_embedding(
&self,
token_id: usize,
device: &candle_core::Device,
) -> Result<Tensor> {
let id_tensor = Tensor::new(&[token_id as u32], device)?;
self.model.codec_embedding.forward(&id_tensor)
}
pub fn forward(
&self,
inputs_embeds: &Tensor,
position_ids: &Tensor,
attention_mask: Option<&Tensor>,
) -> Result<(Tensor, Tensor)> {
let hidden_states = self
.model
.forward(inputs_embeds, position_ids, attention_mask)?;
let logits = self.lm_head.forward(&hidden_states)?;
Ok((logits, hidden_states))
}
pub fn forward_with_cache(
&self,
inputs_embeds: &Tensor,
position_ids: &Tensor,
attention_mask: Option<&Tensor>,
cache: &mut KVCache,
) -> Result<(Tensor, Tensor)> {
let hidden_states =
self.model
.forward_with_cache(inputs_embeds, position_ids, attention_mask, cache)?;
let logits = self.lm_head.forward(&hidden_states)?;
Ok((logits, hidden_states))
}
pub fn num_layers(&self) -> usize {
self.model.num_layers()
}
pub fn generate_step(
&self,
inputs_embeds: &Tensor,
position_ids: &Tensor,
attention_mask: Option<&Tensor>,
sampling_config: &super::sampling::SamplingConfig,
) -> Result<Tensor> {
let (logits, hidden_states) = self.forward(inputs_embeds, position_ids, attention_mask)?;
let last_logits = logits.i((.., logits.dim(1)? - 1, ..))?;
let code_0 = last_logits.argmax(candle_core::D::Minus1)?;
let last_hidden = hidden_states.i((.., hidden_states.dim(1)? - 1, ..))?;
let last_hidden = last_hidden.unsqueeze(1)?;
let subtalker_config = sampling_config.for_subtalker();
let remaining_codes =
self.code_predictor
.generate_with_cache(&last_hidden, None, &subtalker_config)?;
let code_0 = code_0.unsqueeze(1)?;
Tensor::cat(&[&code_0, &remaining_codes], 1)
}
pub fn embed_code(&self, code: &Tensor) -> Result<Tensor> {
self.model.codec_embedding.forward(code)
}
pub fn sum_code_embeddings(&self, all_codes: &Tensor) -> Result<Tensor> {
let num_codebooks = all_codes.dim(1)?;
let code_0 = all_codes.i((.., 0))?;
let mut embed_sum = self.model.codec_embedding.forward(&code_0)?;
for i in 1..num_codebooks.min(self.num_code_groups) {
if let Some(emb_layer) = self.code_predictor.get_input_embedding(i) {
let code_i = all_codes.i((.., i))?;
let embed_i = emb_layer.forward(&code_i)?;
embed_sum = (embed_sum + embed_i)?;
}
}
Ok(embed_sum)
}
}