use anyhow::Result;
use candle_core::Tensor;
use candle_nn::{
Embedding, Linear, Module, RmsNorm, VarBuilder, embedding, linear_no_bias, rms_norm,
};
use crate::{
models::{
common::{
InferenceModel,
modules::{GateUpDownMLP, QKNormAttention},
},
qwen3::config::Qwen3Config,
},
position_embed::rope::RoPE,
utils::tensor_utils::prepare_causal_attention_mask,
};
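/// A single Qwen3 transformer decoder block: RMSNorm -> QK-normalized
/// self-attention -> residual add, followed by RMSNorm -> gated MLP ->
/// residual add (pre-normalization layout).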
pub struct Qwen3DecoderLayer {
self_attn: QKNormAttention,
mlp: GateUpDownMLP,
input_layernorm: RmsNorm,
post_attention_layernorm: RmsNorm,
}
impl Qwen3DecoderLayer {
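    /// Builds one decoder layer from the `config` dimensions, loading its
    /// weights from the `self_attn`, `mlp`, `input_layernorm`, and
    /// `post_attention_layernorm` sub-paths of `vb`.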
pub fn new(config: &Qwen3Config, vb: VarBuilder) -> Result<Self> {
let self_attn = QKNormAttention::new(
vb.pp("self_attn"),
config.hidden_size,
config.num_attention_heads,
Some(config.head_dim),
Some(config.num_key_value_heads),
config.attention_bias,
config.rms_norm_eps,
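            // The remaining optional arguments are not used for Qwen3 and are
            // left unset.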
None,
None,
None,
None,
None,
None,
)?;
let mlp = GateUpDownMLP::new(
vb.pp("mlp"),
config.hidden_size,
config.intermediate_size,
config.hidden_act,
false,
None,
None,
None,
)?;
let input_layernorm = rms_norm(
config.hidden_size,
config.rms_norm_eps,
vb.pp("input_layernorm"),
)?;
let post_attention_layernorm = rms_norm(
config.hidden_size,
config.rms_norm_eps,
vb.pp("post_attention_layernorm"),
)?;
Ok(Self {
self_attn,
mlp,
input_layernorm,
post_attention_layernorm,
})
}
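    /// Runs the layer on `xs` (shape `(batch, seq_len, hidden_size)`), using
    /// the precomputed RoPE `cos`/`sin` tables and an optional causal
    /// `attention_mask`.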
pub fn forward(
&mut self,
xs: &Tensor,
cos: &Tensor,
sin: &Tensor,
attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
let residual = xs.clone();
let xs = self.input_layernorm.forward(xs)?;
let xs = self.self_attn.forward(&xs, cos, sin, attention_mask)?;
let xs = residual.add(&xs)?;
let residual = xs.clone();
let xs = self.post_attention_layernorm.forward(&xs)?;
let xs = self.mlp.forward(&xs)?;
let xs = residual.add(&xs)?;
Ok(xs)
}
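    /// Resets the layer's attention KV cache (call between independent prompts).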
pub fn clear_kv_cache(&mut self) {
self.self_attn.clear_kv_cache();
}
}
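/// Qwen3 causal language model: token embeddings, a stack of
/// [`Qwen3DecoderLayer`]s, a final RMSNorm, and an LM head that is either a
/// separate projection or the (tied) embedding matrix.
///
/// A minimal usage sketch; it assumes a `VarBuilder` already holding the
/// checkpoint weights, a parsed `Qwen3Config`, tokenized `input_ids`, and an
/// `eos_token_id` (that setup is not part of this module):
///
/// ```ignore
/// let mut model = Qwen3Model::new(&config, vb, vec![eos_token_id])?;
/// // Prefill the prompt, then decode one token at a time with a growing offset.
/// let logits = model.forward(Some(&input_ids), None, 0)?;
/// model.clear_kv_cache();
/// ```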
pub struct Qwen3Model {
embed_tokens: Embedding,
layers: Vec<Qwen3DecoderLayer>,
norm: RmsNorm,
rotary_emb: RoPE,
lm_head: Linear,
stop_token_ids: Vec<u32>,
}
impl Qwen3Model {
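    /// Loads the model from `vb`, handling checkpoints that store weights
    /// either under a top-level `model.` prefix or at the root. `eos_ids`
    /// are the token ids that terminate generation.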
pub fn new(config: &Qwen3Config, vb: VarBuilder, eos_ids: Vec<u32>) -> Result<Self> {
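        // Some checkpoints nest all weights under a `model.` prefix; detect
        // that via the embedding weight and adjust the prefix accordingly.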
let vb = if vb.contains_tensor("model.embed_tokens.weight") {
vb.pp("model")
} else {
vb
};
let vocab_size = config.vocab_size;
let embed_tokens = embedding(vocab_size, config.hidden_size, vb.pp("embed_tokens"))?;
        let mut layers = Vec::with_capacity(config.num_hidden_layers);
        let vb_l = vb.pp("layers");
        for layer_idx in 0..config.num_hidden_layers {
            layers.push(Qwen3DecoderLayer::new(config, vb_l.pp(layer_idx))?);
        }
let norm = rms_norm(config.hidden_size, config.rms_norm_eps, vb.pp("norm"))?;
let head_dim = config.head_dim;
let rotary_emb = RoPE::new(head_dim, config.rope_theta, vb.device())?;
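        // With tied word embeddings the LM head reuses the embedding matrix
        // instead of loading a separate `lm_head` projection.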
let lm_head = if config.tie_word_embeddings {
Linear::new(embed_tokens.embeddings().clone(), None)
} else {
linear_no_bias(config.hidden_size, config.vocab_size, vb.pp("lm_head"))?
};
Ok(Self {
embed_tokens,
layers,
norm,
rotary_emb,
lm_head,
stop_token_ids: eos_ids,
})
}
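    /// Full forward pass: computes the last-position hidden state and projects
    /// it through the LM head, returning logits of shape
    /// `(batch, 1, vocab_size)`.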
pub fn forward(
&mut self,
input_ids: Option<&Tensor>,
inputs_embeds: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
let hidden_state = self.forward_hidden(input_ids, inputs_embeds, seqlen_offset)?;
let logits = self.lm_head.forward(&hidden_state)?;
Ok(logits)
}
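    /// Embeds the inputs (from `input_ids` or directly from `inputs_embeds`,
    /// the latter taking precedence when both are given), runs all decoder
    /// layers, applies the final norm, and returns only the hidden state of
    /// the last position, shaped `(batch, 1, hidden_size)`.
    /// `seqlen_offset` is the number of tokens already held in the KV cache.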
pub fn forward_hidden(
&mut self,
input_ids: Option<&Tensor>,
inputs_embeds: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
        if input_ids.is_none() && inputs_embeds.is_none() {
            return Err(anyhow::anyhow!(
                "You must provide either input_ids or inputs_embeds"
            ));
        }
let inputs_embeds = if let Some(inputs_embeds) = inputs_embeds {
inputs_embeds.clone()
} else {
let input_ids = input_ids.unwrap();
self.embedding_token_id(input_ids)?
};
let (bs, seq_len, _) = inputs_embeds.dims3()?;
        // A causal mask is only needed when several query positions are
        // processed at once (prefill); single-token decoding attends to the
        // whole cache unmasked. The offset accounts for tokens already stored
        // in the KV cache.
        let attention_mask: Option<Tensor> = if seq_len <= 1 {
            None
        } else {
            Some(prepare_causal_attention_mask(
                bs,
                seq_len,
                seqlen_offset,
                inputs_embeds.device(),
            )?)
        };
let (cos, sin) = self
.rotary_emb
.forward(seqlen_offset, seq_len, inputs_embeds.device())?;
let mut hidden_states = inputs_embeds;
for decode_layer in &mut self.layers {
hidden_states =
decode_layer.forward(&hidden_states, &cos, &sin, attention_mask.as_ref())?;
}
hidden_states = self.norm.forward(&hidden_states)?;
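        // Only the last position is needed for next-token prediction.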
let hidden_state = hidden_states.narrow(1, seq_len - 1, 1)?;
Ok(hidden_state)
}
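    /// Looks up token embeddings for `input_ids` without running the decoder.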
pub fn embedding_token_id(&self, input_ids: &Tensor) -> Result<Tensor> {
Ok(self.embed_tokens.forward(input_ids)?)
}
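    /// Clears the KV cache of every decoder layer.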
pub fn clear_kv_cache(&mut self) {
for layer in self.layers.iter_mut() {
layer.clear_kv_cache()
}
}
}
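// `InferenceModel` implementation: exposes single-step decoding, cache reset,
// and the stop-token ids so generation code can drive the model generically.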
impl InferenceModel for Qwen3Model {
    fn forward_step(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
        self.forward(Some(input_ids), None, seqlen_offset)
    }
fn clear_cache(&mut self) {
self.clear_kv_cache();
}
fn stop_token_ids(&self) -> Vec<u32> {
self.stop_token_ids.clone()
}
}