use ferrum_kernels::backend::{
Backend, BackendGraph, BackendMoeFused, BackendPagedKv, BackendQuantGguf, BackendQuantMarlin,
LlmBackend, MoeLlmBackend,
};
use ferrum_quantization::loader::WeightLoader;
use ferrum_quantization::PrefixedLoader;
use ferrum_types::Result;
use crate::models::llama_family::{LlamaFamilyConfig, LlamaFamilyModel};
use crate::multimodal::qwen3_tts::TalkerConfig;
/// A stateful transformer backbone used by the TTS talker.
///
/// Implementations take already-embedded input vectors and return the
/// backbone's output hidden states. Implementations keep state between calls
/// (e.g. a position counter and KV cache), which [`reset`](Self::reset)
/// discards.
pub trait TalkerBackboneForward: Send + Sync {
    /// Runs `seq_len` input positions through the backbone.
    ///
    /// `input_f32` holds the `seq_len` input vectors back to back. The
    /// `TalkerBackboneBackend` implementation requires exactly
    /// `seq_len * hidden_size` values.
    fn forward(&mut self, input_f32: &[f32], seq_len: usize) -> Vec<f32>;

    /// Discards all per-sequence state so the next `forward` call starts a
    /// fresh sequence.
    fn reset(&mut self);
}
/// [`TalkerBackboneForward`] implementation backed by a `LlamaFamilyModel`
/// running on backend `B`.
///
/// Each instance owns its own backbone. It also counts how many positions
/// have been fed since the last reset, so consecutive `forward` calls
/// continue the same sequence.
pub struct TalkerBackboneBackend<B: MoeLlmBackend> {
    /// Transformer stack, constructed via `LlamaFamilyModel::new_backbone_only`.
    backbone: LlamaFamilyModel<B>,
    /// Key of this instance's entry in `backbone.kv_caches`.
    cache_id: String,
    /// Positions consumed so far; used as the position offset of the next call.
    pos: usize,
}
impl<B: MoeLlmBackend> TalkerBackboneBackend<B> {
pub fn new(cfg: &TalkerConfig, loader: &dyn WeightLoader<B>) -> Result<Self> {
let backbone_cfg = LlamaFamilyConfig {
hidden_size: cfg.hidden_size,
intermediate_size: cfg.intermediate_size,
num_heads: cfg.num_attention_heads,
num_kv_heads: cfg.num_key_value_heads,
head_dim: cfg.head_dim,
num_layers: cfg.num_hidden_layers,
vocab_size: cfg.vocab_size,
max_seq_len: cfg.max_position_embeddings,
rms_norm_eps: cfg.rms_norm_eps as f32,
rope_theta: cfg.rope_theta,
rope_scaling: None,
rope_interleaved: false,
has_qk_norm: true,
sliding_window: 0,
};
let prefixed = PrefixedLoader::new(loader, "talker.");
let backbone = LlamaFamilyModel::<B>::new_backbone_only(backbone_cfg, &prefixed)?;
Ok(Self {
backbone,
cache_id: "tts-talker".to_string(),
pos: 0,
})
}
pub fn new_code_predictor(cfg: &TalkerConfig, loader: &dyn WeightLoader<B>) -> Result<Self> {
let cp_h = cfg.code_predictor_hidden_size;
let backbone_cfg = LlamaFamilyConfig {
hidden_size: cp_h,
intermediate_size: cp_h * 3,
num_heads: cfg.code_predictor_num_heads,
num_kv_heads: cfg.code_predictor_num_kv_heads,
head_dim: cfg.code_predictor_head_dim,
num_layers: cfg.code_predictor_num_layers,
vocab_size: cfg.code_predictor_vocab_size,
max_seq_len: cfg.max_position_embeddings,
rms_norm_eps: cfg.rms_norm_eps as f32,
rope_theta: cfg.rope_theta,
rope_scaling: None,
rope_interleaved: false,
has_qk_norm: true,
sliding_window: 0,
};
let prefixed = PrefixedLoader::new(loader, "talker.code_predictor.");
let backbone = LlamaFamilyModel::<B>::new_backbone_only(backbone_cfg, &prefixed)?;
Ok(Self {
backbone,
cache_id: "tts-subtalker".to_string(),
pos: 0,
})
}
}
impl<B: MoeLlmBackend> TalkerBackboneForward for TalkerBackboneBackend<B> {
    /// Runs `seq_len` embedded positions through the backbone.
    ///
    /// The call starts at the position offset left by the previous calls since
    /// the last [`reset`](TalkerBackboneForward::reset).
    ///
    /// - `seq_len == 1` uses the backbone's single-token decode entry point.
    /// - `seq_len > 1` uses its multi-token prefill entry point.
    /// - `seq_len == 0` returns an empty vector without touching the backbone.
    ///
    /// Both paths receive `self.cache_id` as the KV-cache key, which is the
    /// same key `reset` removes from `kv_caches`.
    ///
    /// # Panics
    ///
    /// - If `input_f32.len() != seq_len * hidden_size`.
    /// - If the position offset does not fit in the `u32` the decode entry
    ///   point takes.
    fn forward(&mut self, input_f32: &[f32], seq_len: usize) -> Vec<f32> {
        let h = self.backbone.cfg.hidden_size;
        assert_eq!(
            input_f32.len(),
            seq_len * h,
            "TalkerBackboneBackend: input len {} != seq_len * hidden {}",
            input_f32.len(),
            seq_len * h
        );
        tracing::debug!(
            "TalkerBackboneBackend::forward cache={} seq_len={} pos_offset={}",
            self.cache_id,
            seq_len,
            self.pos
        );
        // An empty batch has nothing to compute. Don't hand a zero-length
        // sequence to the prefill path; the position does not advance.
        if seq_len == 0 {
            return Vec::new();
        }
        let out = if seq_len == 1 {
            // The decode entry point takes a `u32` position. Fail loudly
            // instead of silently wrapping with an `as` cast.
            let pos = u32::try_from(self.pos)
                .expect("TalkerBackboneBackend: position offset exceeds u32::MAX");
            self.backbone
                .decode_post_norm_from_embed(&self.cache_id, input_f32, pos)
        } else {
            self.backbone
                .prefill_all_post_norm(&self.cache_id, input_f32, seq_len, self.pos)
        };
        self.pos += seq_len;
        out
    }

    /// Rewinds the position counter to 0 and drops this instance's KV-cache
    /// entry, so the next `forward` starts a new sequence.
    fn reset(&mut self) {
        self.pos = 0;
        self.backbone.kv_caches.remove(&self.cache_id);
    }
}