use ferrum_kernels::backend::{
Backend, BackendGraph, BackendMoeFused, BackendPagedKv, BackendQuantGguf, BackendQuantMarlin,
LlmBackend, MoeLlmBackend,
};
use ferrum_quantization::loader::WeightLoader;
use ferrum_quantization::traits::Linear;
use ferrum_quantization::PrefixedLoader;
use ferrum_types::Result;
use std::collections::HashMap;
use crate::models::llama_family::{LlamaFamilyConfig, LlamaFamilyModel};
use crate::multimodal::qwen3_tts::TalkerConfig;
/// Qwen3-TTS talker: a Llama-family transformer backbone together with the text/codec
/// embedding tables, the text projection MLP and the codec output head.
pub struct Qwen3TtsTalker<B>
where
    B: MoeLlmBackend,
{
    /// Hyper-parameters the talker was built from.
    pub cfg: TalkerConfig,
    /// Transformer layers (loaded under the `talker.` prefix), final norm and KV caches.
    pub backbone: LlamaFamilyModel<B>,
    /// Text token embedding table (`text_hidden_size` wide).
    pub text_embedding: B::Buffer,
    /// First layer of the text projection MLP.
    pub text_proj_fc1: Box<dyn Linear<B>>,
    /// Second layer of the text projection MLP (projects to `hidden_size`).
    pub text_proj_fc2: Box<dyn Linear<B>>,
    /// Codec token embedding table (`hidden_size` wide).
    pub codec_embedding: B::Buffer,
    /// Output head mapping normalized hidden states to codec-token logits.
    pub codec_head: Box<dyn Linear<B>>,
    /// Number of positions already written to each KV cache, keyed by cache id.
    positions: HashMap<String, u32>,
}
impl<B: MoeLlmBackend> Qwen3TtsTalker<B> {
    /// Loads the talker from a checkpoint whose tensors live under `talker.`.
    ///
    /// The transformer layers are loaded through a [`PrefixedLoader`] rooted at `talker.`.
    /// The text/codec embedding tables, the two-layer text projection and the codec head
    /// are loaded from `loader` by their full names.
    ///
    /// # Errors
    /// Returns the first error raised while building the backbone or loading any tensor.
    pub fn new(cfg: TalkerConfig, loader: &dyn WeightLoader<B>) -> Result<Self> {
        let backbone_cfg = LlamaFamilyConfig {
            hidden_size: cfg.hidden_size,
            intermediate_size: cfg.intermediate_size,
            num_heads: cfg.num_attention_heads,
            num_kv_heads: cfg.num_key_value_heads,
            head_dim: cfg.head_dim,
            num_layers: cfg.num_hidden_layers,
            vocab_size: cfg.vocab_size,
            max_seq_len: cfg.max_position_embeddings,
            rms_norm_eps: cfg.rms_norm_eps as f32,
            rope_theta: cfg.rope_theta,
            rope_scaling: None,
            rope_interleaved: false,
            has_qk_norm: true,
            sliding_window: 0,
        };
        let talker_loader = PrefixedLoader::new(loader, "talker.");
        let backbone = LlamaFamilyModel::<B>::new_backbone_only(backbone_cfg, &talker_loader)?;
        let text_embedding = loader.load_tensor("talker.model.text_embedding.weight")?;
        let codec_embedding = loader.load_tensor("talker.model.codec_embedding.weight")?;
        let text_proj_fc1 = loader.load_linear("talker.text_projection.linear_fc1")?;
        let text_proj_fc2 = loader.load_linear("talker.text_projection.linear_fc2")?;
        let codec_head = loader.load_linear("talker.codec_head")?;
        Ok(Self {
            cfg,
            backbone,
            text_embedding,
            text_proj_fc1,
            text_proj_fc2,
            codec_embedding,
            codec_head,
            positions: HashMap::new(),
        })
    }

    /// Drops the KV cache and the decode position tracked for `cache_id`.
    pub fn reset(&mut self, cache_id: &str) {
        self.positions.remove(cache_id);
        self.backbone.kv_caches.remove(cache_id);
    }

    /// Drops every KV cache and every tracked decode position.
    pub fn reset_all(&mut self) {
        self.positions.clear();
        self.backbone.kv_caches.clear();
    }

    /// Embeds one text token into the talker's `hidden_size` space.
    ///
    /// The token is looked up in `text_embedding` (`text_hidden_size` wide). The result then
    /// passes through `linear_fc1`, SiLU and `linear_fc2`. The SiLU is evaluated on the
    /// host between two device syncs.
    fn embed_text_token(&self, token: u32) -> Vec<f32> {
        let text_hidden = self.cfg.text_hidden_size;
        let hidden = self.cfg.hidden_size;
        let mut ctx = B::new_context();
        let mut embed_out = B::alloc(text_hidden);
        B::embedding_lookup(
            &mut ctx,
            &self.text_embedding,
            &[token],
            &mut embed_out,
            text_hidden,
        );
        let mut fc1_out = B::alloc(text_hidden);
        self.text_proj_fc1
            .forward(&mut ctx, &embed_out, &mut fc1_out, 1);
        B::sync(&mut ctx);
        let fc1_host = B::to_vec(&fc1_out, text_hidden);
        // SiLU(x) = x * sigmoid(x).
        let silu_host: Vec<f32> = fc1_host
            .iter()
            .map(|&x| x * (1.0f32 / (1.0f32 + (-x).exp())))
            .collect();
        let silu_dev = B::from_slice(&silu_host);
        let mut fc2_out = B::alloc(hidden);
        self.text_proj_fc2
            .forward(&mut ctx, &silu_dev, &mut fc2_out, 1);
        B::sync(&mut ctx);
        B::to_vec(&fc2_out, hidden)
    }

    /// Embeds one codec token. This is the same operation as [`Self::codec_embed_lookup`].
    fn embed_codec_token(&self, token: u32) -> Vec<f32> {
        self.codec_embed_lookup(token)
    }

    /// Prefills the KV cache `cache_id` with a mixed text/codec prompt and returns codec
    /// logits (`vocab_size` values).
    ///
    /// Each `(token, is_text)` pair is embedded with the text path when `is_text` is true
    /// and with the codec embedding otherwise. The cache's next decode position is then
    /// set to `tokens.len()`.
    ///
    /// Assumes the backbone's `prefill_from_embeds` returns a single `hidden_size`
    /// vector, presumably the last position's. `apply_head` debug-asserts that length.
    pub fn prefill(&mut self, cache_id: &str, tokens: &[(u32, bool)]) -> Vec<f32> {
        let h = self.cfg.hidden_size;
        let seq_len = tokens.len();
        let mut mixed = Vec::with_capacity(seq_len * h);
        for &(tok, is_text) in tokens {
            let emb = if is_text {
                self.embed_text_token(tok)
            } else {
                self.embed_codec_token(tok)
            };
            mixed.extend_from_slice(&emb);
        }
        let pre_norm_hidden = self.backbone.prefill_from_embeds(cache_id, &mixed, seq_len);
        self.positions.insert(cache_id.to_string(), seq_len as u32);
        self.apply_head(&pre_norm_hidden)
    }

    /// Decodes one codec token on the KV cache `cache_id` and returns codec logits.
    ///
    /// The token is placed at the cache's tracked position. That is 0 when the cache was
    /// never prefilled. The position advances by one only after the backbone step
    /// returns.
    pub fn decode_codec(&mut self, cache_id: &str, token: u32) -> Vec<f32> {
        let pos = self.positions.get(cache_id).copied().unwrap_or(0);
        let embed = self.embed_codec_token(token);
        let pre_norm = self.backbone.decode_from_embed(cache_id, &embed, pos);
        self.positions.insert(cache_id.to_string(), pos + 1);
        self.apply_head(&pre_norm)
    }

    /// Applies the backbone's final RMSNorm and the codec head to one `hidden_size`
    /// vector, returning `vocab_size` logits.
    fn apply_head(&self, hidden_f32: &[f32]) -> Vec<f32> {
        let h = self.cfg.hidden_size;
        let vocab = self.cfg.vocab_size;
        debug_assert_eq!(hidden_f32.len(), h);
        let mut ctx = B::new_context();
        let hidden_buf = B::from_slice(hidden_f32);
        let mut normed = B::alloc(h);
        B::rms_norm(
            &mut ctx,
            &hidden_buf,
            &self.backbone.final_norm_w,
            self.cfg.rms_norm_eps as f32,
            &mut normed,
            1,
            h,
        );
        let mut logits = B::alloc(vocab);
        self.codec_head.forward(&mut ctx, &normed, &mut logits, 1);
        B::sync(&mut ctx);
        B::to_vec(&logits, vocab)
    }

    /// Returns `backbone.scratch.last_hidden` after the final RMSNorm (`hidden_size` values).
    ///
    /// NOTE(review): `cache_id` is ignored. `scratch` is a single backbone field that is not
    /// keyed by cache. The result presumably reflects whichever cache the backbone ran
    /// most recently, so call this right after `prefill`/`decode_codec` on the intended
    /// cache.
    pub fn last_hidden_normed(&mut self, cache_id: &str) -> Vec<f32> {
        let _ = cache_id;
        let h = self.cfg.hidden_size;
        let mut ctx = B::new_context();
        let mut normed = B::alloc(h);
        B::rms_norm(
            &mut ctx,
            &self.backbone.scratch.last_hidden,
            &self.backbone.final_norm_w,
            self.cfg.rms_norm_eps as f32,
            &mut normed,
            1,
            h,
        );
        B::sync(&mut ctx);
        B::to_vec(&normed, h)
    }

    /// Looks `token` up in the talker's codec embedding table and returns the
    /// `hidden_size`-wide row.
    pub fn codec_embed_lookup(&self, token: u32) -> Vec<f32> {
        let mut ctx = B::new_context();
        let h = self.cfg.hidden_size;
        let mut out = B::alloc(h);
        B::embedding_lookup(&mut ctx, &self.codec_embedding, &[token], &mut out, h);
        B::sync(&mut ctx);
        B::to_vec(&out, h)
    }
}
/// Qwen3-TTS code predictor ("sub-talker"). It is a small Llama-family backbone that
/// predicts the remaining codebook entries of a frame from the talker's hidden state and
/// the first codec embedding.
pub struct Qwen3TtsSubTalker<B>
where
    B: MoeLlmBackend,
{
    /// Shared talker config; the `code_predictor_*` fields size this model.
    pub cfg: TalkerConfig,
    /// Transformer layers, final norm and KV caches of the code predictor.
    pub backbone: LlamaFamilyModel<B>,
    /// Talker-width to code-predictor-width projection; `None` when the widths are equal.
    pub projection: Option<Box<dyn Linear<B>>>,
    /// One embedding table per extra codebook.
    pub codec_embeddings: Vec<B::Buffer>,
    /// One output head per extra codebook.
    pub lm_heads: Vec<Box<dyn Linear<B>>>,
}
impl<B: MoeLlmBackend> Qwen3TtsSubTalker<B> {
    /// Number of codebooks the sub-talker predicts (every group except the first).
    ///
    /// Saturates at zero so a config with `num_code_groups == 0` yields no extra groups
    /// instead of underflowing.
    fn extra_groups(cfg: &TalkerConfig) -> usize {
        cfg.num_code_groups.saturating_sub(1)
    }

    /// Greedy token choice: index of the largest logit.
    ///
    /// Ties resolve to the lowest index, matching the usual argmax convention. NaN
    /// entries are skipped, and an empty or all-NaN slice yields 0.
    fn argmax(logits: &[f32]) -> u32 {
        logits
            .iter()
            .enumerate()
            .filter(|(_, v)| !v.is_nan())
            .fold(None, |best: Option<(usize, f32)>, (idx, &v)| match best {
                Some((_, b)) if v <= b => best,
                _ => Some((idx, v)),
            })
            .map_or(0, |(idx, _)| idx as u32)
    }

    /// Loads the code predictor from tensors under `talker.code_predictor.`.
    ///
    /// `small_to_mtp_projection` is loaded only when `hidden_size` differs from
    /// `code_predictor_hidden_size`. One codec embedding and one LM head are loaded per
    /// extra codebook.
    ///
    /// # Errors
    /// Returns the first error raised while building the backbone or loading any tensor.
    pub fn new(cfg: TalkerConfig, loader: &dyn WeightLoader<B>) -> Result<Self> {
        let cp_h = cfg.code_predictor_hidden_size;
        // NOTE(review): the MLP width (3x hidden) and head_dim (hidden / heads) are derived
        // here rather than read from the checkpoint config. Confirm they match the
        // code-predictor weights being loaded.
        let cp_im = cp_h * 3;
        let backbone_cfg = LlamaFamilyConfig {
            hidden_size: cp_h,
            intermediate_size: cp_im,
            num_heads: cfg.code_predictor_num_heads,
            num_kv_heads: cfg.code_predictor_num_kv_heads,
            head_dim: cp_h / cfg.code_predictor_num_heads,
            num_layers: cfg.code_predictor_num_layers,
            vocab_size: cfg.code_predictor_vocab_size,
            max_seq_len: cfg.max_position_embeddings,
            rms_norm_eps: cfg.rms_norm_eps as f32,
            rope_theta: cfg.rope_theta,
            rope_scaling: None,
            rope_interleaved: false,
            has_qk_norm: true,
            sliding_window: 0,
        };
        let cp_loader = PrefixedLoader::new(loader, "talker.code_predictor.");
        let backbone = LlamaFamilyModel::<B>::new_backbone_only(backbone_cfg, &cp_loader)?;
        let projection = if cfg.hidden_size != cp_h {
            Some(loader.load_linear("talker.code_predictor.small_to_mtp_projection")?)
        } else {
            None
        };
        let n_extra = Self::extra_groups(&cfg);
        let codec_embeddings = (0..n_extra)
            .map(|i| {
                loader.load_tensor(&format!(
                    "talker.code_predictor.model.codec_embedding.{i}.weight"
                ))
            })
            .collect::<Result<Vec<_>>>()?;
        let lm_heads = (0..n_extra)
            .map(|i| loader.load_linear(&format!("talker.code_predictor.lm_head.{i}")))
            .collect::<Result<Vec<_>>>()?;
        Ok(Self {
            cfg,
            backbone,
            projection,
            codec_embeddings,
            lm_heads,
        })
    }

    /// Greedily predicts the remaining `num_code_groups - 1` codes of one frame.
    ///
    /// The KV cache `cache_id` is cleared first. It is then prefilled with two positions:
    /// `talker_hidden` and `first_codec_embed`, both `hidden_size` wide and projected to
    /// the code-predictor width when a projection exists. Step `i` normalizes the
    /// backbone's `scratch.last_hidden` and takes the argmax of `lm_heads[i]`. The chosen
    /// code is embedded with `codec_embeddings[i]` and decoded at the next position. The
    /// final step skips that feed-back.
    ///
    /// Returns an empty vector (without touching the cache) when there are no extra groups.
    pub fn predict_greedy(
        &mut self,
        cache_id: &str,
        talker_hidden: &[f32],
        first_codec_embed: &[f32],
    ) -> Vec<u32> {
        let h_talker = self.cfg.hidden_size;
        let cp_h = self.cfg.code_predictor_hidden_size;
        let n_extra = Self::extra_groups(&self.cfg);
        debug_assert_eq!(talker_hidden.len(), h_talker);
        debug_assert_eq!(first_codec_embed.len(), h_talker);
        if n_extra == 0 {
            return Vec::new();
        }
        self.backbone.kv_caches.remove(cache_id);
        let mut combined = Vec::with_capacity(2 * h_talker);
        combined.extend_from_slice(talker_hidden);
        combined.extend_from_slice(first_codec_embed);
        // Without a projection the widths are equal, so `combined` is already 2 * cp_h long.
        let projected: Vec<f32> = if let Some(ref proj) = self.projection {
            let mut ctx = B::new_context();
            let in_buf = B::from_slice(&combined);
            let mut out = B::alloc(2 * cp_h);
            proj.forward(&mut ctx, &in_buf, &mut out, 2);
            B::sync(&mut ctx);
            B::to_vec(&out, 2 * cp_h)
        } else {
            combined
        };
        let _ = self.backbone.prefill_from_embeds(cache_id, &projected, 2);
        let mut pos: u32 = 2;
        let mut predicted = Vec::with_capacity(n_extra);
        let vocab = self.cfg.code_predictor_vocab_size;
        for i in 0..n_extra {
            let mut ctx = B::new_context();
            let mut normed = B::alloc(cp_h);
            // `scratch.last_hidden` is presumably the hidden state of the position most
            // recently run through the backbone. TODO confirm against LlamaFamilyModel.
            B::rms_norm(
                &mut ctx,
                &self.backbone.scratch.last_hidden,
                &self.backbone.final_norm_w,
                self.cfg.rms_norm_eps as f32,
                &mut normed,
                1,
                cp_h,
            );
            let mut logits = B::alloc(vocab);
            self.lm_heads[i].forward(&mut ctx, &normed, &mut logits, 1);
            B::sync(&mut ctx);
            let logits_host = B::to_vec(&logits, vocab);
            let token = Self::argmax(&logits_host);
            predicted.push(token);
            if i + 1 == n_extra {
                break;
            }
            let mut ctx2 = B::new_context();
            let mut emb = B::alloc(h_talker);
            B::embedding_lookup(
                &mut ctx2,
                &self.codec_embeddings[i],
                &[token],
                &mut emb,
                h_talker,
            );
            B::sync(&mut ctx2);
            let emb_host = B::to_vec(&emb, h_talker);
            let next_embed: Vec<f32> = if let Some(ref proj) = self.projection {
                let mut ctx3 = B::new_context();
                let in_buf = B::from_slice(&emb_host);
                let mut out = B::alloc(cp_h);
                proj.forward(&mut ctx3, &in_buf, &mut out, 1);
                B::sync(&mut ctx3);
                B::to_vec(&out, cp_h)
            } else {
                emb_host
            };
            let _ = self.backbone.decode_from_embed(cache_id, &next_embed, pos);
            pos += 1;
        }
        predicted
    }
}