use anyhow::Result;
use candle_core::Tensor;
use candle_nn::{
Activation, Linear, Module, RmsNorm, VarBuilder, linear, linear_no_bias, rms_norm,
};
use crate::{
models::common::modules::{GateUpDownMLP, eager_attention_forward},
position_embed::rope::{RoPE, apply_rotary_pos_emb},
};
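/// Model hyper-parameters for the Qwen2 family, mirroring the field names of
/// the Hugging Face `config.json` so it can be deserialized directly with serde.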
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
pub struct Qwen2Config {
pub vocab_size: usize,
pub hidden_size: usize,
pub intermediate_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub num_key_value_heads: usize,
pub max_position_embeddings: usize,
pub sliding_window: usize,
pub max_window_layers: usize,
pub tie_word_embeddings: bool,
pub rope_theta: f32,
pub rms_norm_eps: f64,
pub use_sliding_window: bool,
pub hidden_act: Activation,
}
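/// Grouped-query self-attention (GQA): `num_heads` query heads share
/// `num_kv_heads` key/value heads. As in the reference Qwen2 weights, the
/// Q/K/V projections carry a bias while the output projection does not.
/// An optional KV cache supports incremental decoding.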
#[derive(Debug, Clone)]
pub struct Qwen2Attention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
o_proj: Linear,
num_heads: usize,
num_kv_heads: usize,
num_kv_groups: usize,
head_dim: usize,
hidden_size: usize,
kv_cache: Option<(Tensor, Tensor)>,
}
impl Qwen2Attention {
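    /// Loads the four projection weights from `vb` and derives the head
    /// dimension from `hidden_size / num_attention_heads`.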
pub fn new(cfg: &Qwen2Config, vb: VarBuilder) -> Result<Self> {
let hidden_size = cfg.hidden_size;
let num_heads = cfg.num_attention_heads;
let num_kv_heads = cfg.num_key_value_heads;
let num_kv_groups = num_heads / num_kv_heads;
let head_dim = hidden_size / num_heads;
let q_proj = linear(hidden_size, num_heads * head_dim, vb.pp("q_proj"))?;
let k_proj = linear(hidden_size, num_kv_heads * head_dim, vb.pp("k_proj"))?;
let v_proj = linear(hidden_size, num_kv_heads * head_dim, vb.pp("v_proj"))?;
let o_proj = linear_no_bias(hidden_size, hidden_size, vb.pp("o_proj"))?;
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
num_heads,
num_kv_heads,
num_kv_groups,
head_dim,
hidden_size,
kv_cache: None,
})
}
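    /// Cached forward pass: the freshly projected keys/values are appended to
    /// the KV cache, so `cos`/`sin` must correspond to the absolute positions
    /// of the new tokens.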
pub fn forward(
&mut self,
xs: &Tensor,
cos: &Tensor,
sin: &Tensor,
attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
let (b_sz, q_len, _) = xs.dims3()?;
let query_states = self.q_proj.forward(xs)?;
let key_states = self.k_proj.forward(xs)?;
let value_states = self.v_proj.forward(xs)?;
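        // Split heads: (b, seq, heads * head_dim) -> (b, heads, seq, head_dim).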
let query_states = query_states
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
.transpose(1, 2)?;
let key_states = key_states
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let value_states = value_states
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let (query_states, key_states) =
apply_rotary_pos_emb(&query_states, &key_states, cos, sin, false)?;
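        // Append the new keys/values to the cache along the sequence dimension (dim 2).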
let (key_states, value_states) = match &self.kv_cache {
None => (key_states, value_states),
Some((prev_k, prev_v)) => {
let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
(key_states, value_states)
}
};
self.kv_cache = Some((key_states.clone(), value_states.clone()));
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
let attn_output = eager_attention_forward(
&query_states,
&key_states,
&value_states,
Some(self.num_kv_groups),
attention_mask,
scale,
)?;
let attn_output = attn_output.reshape((b_sz, q_len, self.hidden_size))?;
let attn_output = attn_output.apply(&self.o_proj)?;
Ok(attn_output)
}
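    /// Stateless forward pass: identical to [`Self::forward`] except that the
    /// KV cache is neither read nor updated.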
pub fn forward_no_cache(
&self,
xs: &Tensor,
cos: &Tensor,
sin: &Tensor,
attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
let (b_sz, q_len, _) = xs.dims3()?;
let query_states = self.q_proj.forward(xs)?;
let key_states = self.k_proj.forward(xs)?;
let value_states = self.v_proj.forward(xs)?;
let query_states = query_states
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
.transpose(1, 2)?;
let key_states = key_states
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let value_states = value_states
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let (query_states, key_states) =
apply_rotary_pos_emb(&query_states, &key_states, cos, sin, false)?;
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
let attn_output = eager_attention_forward(
&query_states,
&key_states,
&value_states,
Some(self.num_kv_groups),
attention_mask,
scale,
)?;
let attn_output = attn_output.reshape((b_sz, q_len, self.hidden_size))?;
let attn_output = attn_output.apply(&self.o_proj)?;
Ok(attn_output)
}
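    /// Drops the cached keys/values, e.g. before decoding a new sequence.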
pub fn clear_kv_cache(&mut self) {
        self.kv_cache = None;
}
}
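/// A single pre-norm transformer block: RMSNorm -> self-attention -> residual
/// add, then RMSNorm -> gated MLP -> residual add.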
#[derive(Debug, Clone)]
pub struct Qwen2DecoderLayer {
self_attn: Qwen2Attention,
mlp: GateUpDownMLP,
input_layernorm: RmsNorm,
post_attention_layernorm: RmsNorm,
}
impl Qwen2DecoderLayer {
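    /// Builds the attention block, gated MLP, and both RMSNorm layers from `vb`.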
pub fn new(cfg: &Qwen2Config, vb: VarBuilder) -> Result<Self> {
let self_attn = Qwen2Attention::new(cfg, vb.pp("self_attn"))?;
let mlp = GateUpDownMLP::new(
vb.pp("mlp"),
cfg.hidden_size,
cfg.intermediate_size,
cfg.hidden_act,
false,
None,
None,
None,
)?;
let input_layernorm =
rms_norm(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
let post_attention_layernorm = rms_norm(
cfg.hidden_size,
cfg.rms_norm_eps,
vb.pp("post_attention_layernorm"),
)?;
Ok(Self {
self_attn,
mlp,
input_layernorm,
post_attention_layernorm,
})
}
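    /// Runs the block with the attention KV cache enabled.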
pub fn forward(
&mut self,
xs: &Tensor,
cos: &Tensor,
sin: &Tensor,
attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
let residual = xs;
let xs = self.input_layernorm.forward(xs)?;
let xs = self.self_attn.forward(&xs, cos, sin, attention_mask)?;
let xs = (xs + residual)?;
let residual = &xs;
let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
let xs = (residual + xs)?;
Ok(xs)
}
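    /// Runs the block without reading or updating the attention KV cache.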
pub fn forward_no_cache(
&self,
xs: &Tensor,
cos: &Tensor,
sin: &Tensor,
attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
let residual = xs;
let xs = self.input_layernorm.forward(xs)?;
let xs = self
.self_attn
.forward_no_cache(&xs, cos, sin, attention_mask)?;
let xs = (xs + residual)?;
let residual = &xs;
let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
let xs = (residual + xs)?;
Ok(xs)
}
pub fn clear_kv_cache(&mut self) {
self.self_attn.clear_kv_cache()
}
}
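/// The Qwen2 decoder stack: `num_hidden_layers` decoder layers followed by a
/// final RMSNorm. Rotary cos/sin tables are computed once per call and shared
/// across all layers. A minimal usage sketch (the config values are
/// illustrative, not those of a real checkpoint):
///
/// ```ignore
/// use candle_core::{DType, Device, Tensor};
/// use candle_nn::{Activation, VarBuilder};
///
/// let device = Device::Cpu;
/// // Zero-initialised weights, just to exercise the shapes.
/// let vb = VarBuilder::zeros(DType::F32, &device);
/// let cfg = Qwen2Config {
///     vocab_size: 151_936,
///     hidden_size: 64,
///     intermediate_size: 128,
///     num_hidden_layers: 2,
///     num_attention_heads: 8,
///     num_key_value_heads: 2,
///     max_position_embeddings: 4096,
///     sliding_window: 4096,
///     max_window_layers: 2,
///     tie_word_embeddings: false,
///     rope_theta: 10_000.0,
///     rms_norm_eps: 1e-6,
///     use_sliding_window: false,
///     hidden_act: Activation::Silu,
/// };
/// let decoder = Qwen2Decoder::new(vb, &cfg)?;
/// // Already-embedded hidden states of shape (batch, seq_len, hidden_size).
/// let xs = Tensor::zeros((1, 4, cfg.hidden_size), DType::F32, &device)?;
/// let ys = decoder.forward_no_cache(&xs, None, 0)?;
/// ```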
pub struct Qwen2Decoder {
layers: Vec<Qwen2DecoderLayer>,
norm: RmsNorm,
rotary_emb: RoPE,
}
impl Qwen2Decoder {
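    /// Builds every decoder layer under `vb.pp("layers")`, the final norm, and
    /// the rotary embedding table on the same device as `vb`.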
pub fn new(vb: VarBuilder, cfg: &Qwen2Config) -> Result<Self> {
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
let vb_l = vb.pp("layers");
for layer_idx in 0..cfg.num_hidden_layers {
let layer = Qwen2DecoderLayer::new(cfg, vb_l.pp(layer_idx))?;
            layers.push(layer);
}
let norm = rms_norm(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("norm"))?;
let head_dim = cfg.hidden_size / cfg.num_attention_heads;
let rotary_emb = RoPE::new(head_dim, cfg.rope_theta, vb.device())?;
Ok(Self {
layers,
norm,
rotary_emb,
})
}
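    /// Runs already-embedded hidden states of shape `(batch, seq_len, hidden_size)`
    /// through every layer without KV caching; `seqlen_offset` is the absolute
    /// position of the first token and is used to index the rotary tables.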
pub fn forward_no_cache(
&self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
let seq_len = xs.dim(1)?;
let (cos, sin) = self
.rotary_emb
.forward(seqlen_offset, seq_len, xs.device())?;
let mut xs = xs.clone();
for layer in self.layers.iter() {
xs = layer.forward_no_cache(&xs, &cos, &sin, attention_mask)?;
}
let xs = xs.apply(&self.norm)?;
Ok(xs)
}
}