use candle_core::{Result, Tensor};
use candle_nn::{Linear, Module, VarBuilder};
use crate::tensor_utils::{apply_rotary_emb, RmsNorm};
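
/// Configuration for a grouped-query attention (GQA) layer.
///
/// Each of the `num_key_value_heads` key/value heads is shared by
/// `num_attention_heads / num_key_value_heads` query heads, so
/// `num_key_value_heads` should evenly divide `num_attention_heads`.
/// `rope_theta` and `max_position_embeddings` are not used directly in this
/// module; they are carried for whoever precomputes the rotary `cos`/`sin`
/// tables passed to [`GroupedQueryAttention::forward`].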
#[derive(Debug, Clone)]
pub struct GqaConfig {
    pub hidden_size: usize,
    pub num_attention_heads: usize,
    pub num_key_value_heads: usize,
    pub head_dim: usize,
    pub max_position_embeddings: usize,
    pub rope_theta: f64,
    pub rms_norm_eps: f64,
    pub attention_bias: bool,
}

impl GqaConfig {
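    /// Create a config whose `head_dim` is derived as `hidden_size / num_attention_heads`.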
    pub fn new(
        hidden_size: usize,
        num_attention_heads: usize,
        num_key_value_heads: usize,
        max_position_embeddings: usize,
        rope_theta: f64,
        rms_norm_eps: f64,
    ) -> Self {
        let head_dim = hidden_size / num_attention_heads;
        Self {
            hidden_size,
            num_attention_heads,
            num_key_value_heads,
            head_dim,
            max_position_embeddings,
            rope_theta,
            rms_norm_eps,
            attention_bias: false,
        }
    }
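
    /// Create a config with an explicit `head_dim`, for models where
    /// `num_attention_heads * head_dim` differs from `hidden_size`.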
    pub fn with_head_dim(
        hidden_size: usize,
        num_attention_heads: usize,
        num_key_value_heads: usize,
        head_dim: usize,
        max_position_embeddings: usize,
        rope_theta: f64,
        rms_norm_eps: f64,
    ) -> Self {
        Self {
            hidden_size,
            num_attention_heads,
            num_key_value_heads,
            head_dim,
            max_position_embeddings,
            rope_theta,
            rms_norm_eps,
            attention_bias: false,
        }
    }
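
    /// Builder-style toggle for bias terms on the q/k/v projections.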
    pub fn with_attention_bias(mut self, attention_bias: bool) -> Self {
        self.attention_bias = attention_bias;
        self
    }
}
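
/// Multi-head attention where several query heads share each key/value head,
/// with rotary position embeddings, optional per-head RMSNorm on queries and
/// keys, and an internal KV cache for incremental decoding.
///
/// A minimal usage sketch, assuming zero-initialized weights from
/// `VarBuilder::zeros` and eliding construction of the rotary `cos`/`sin`
/// tables (their exact layout is whatever `crate::tensor_utils::apply_rotary_emb`
/// expects):
///
/// ```ignore
/// use candle_core::{DType, Device};
/// use candle_nn::VarBuilder;
///
/// let device = Device::Cpu;
/// let vb = VarBuilder::zeros(DType::F32, &device);
///
/// // 8 query heads sharing 2 key/value heads (repeat factor 4).
/// let config = GqaConfig::new(512, 8, 2, 2048, 10_000.0, 1e-6);
/// let mut attn = GroupedQueryAttention::load(&config, vb)?;
///
/// // `cos`/`sin` are precomputed rotary tables; `x` is (batch, seq_len, hidden_size).
/// // let out = attn.forward(&x, &cos, &sin, 0, None)?;
/// ```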
pub struct GroupedQueryAttention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    o_proj: Linear,
    q_norm: Option<RmsNorm>,
    k_norm: Option<RmsNorm>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    kv_cache: Option<(Tensor, Tensor)>,
}

impl GroupedQueryAttention {
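    /// Load the projection weights (and optional `q_norm`/`k_norm`) from `vb`.
    ///
    /// The norms are loaded with `.ok()`, so checkpoints without those tensors
    /// simply run without per-head query/key normalization.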
    pub fn load(config: &GqaConfig, vb: VarBuilder) -> Result<Self> {
        let q_dim = config.num_attention_heads * config.head_dim;
        let kv_dim = config.num_key_value_heads * config.head_dim;
        let q_proj = if config.attention_bias {
            candle_nn::linear(config.hidden_size, q_dim, vb.pp("q_proj"))?
        } else {
            candle_nn::linear_no_bias(config.hidden_size, q_dim, vb.pp("q_proj"))?
        };
        let k_proj = if config.attention_bias {
            candle_nn::linear(config.hidden_size, kv_dim, vb.pp("k_proj"))?
        } else {
            candle_nn::linear_no_bias(config.hidden_size, kv_dim, vb.pp("k_proj"))?
        };
        let v_proj = if config.attention_bias {
            candle_nn::linear(config.hidden_size, kv_dim, vb.pp("v_proj"))?
        } else {
            candle_nn::linear_no_bias(config.hidden_size, kv_dim, vb.pp("v_proj"))?
        };
        let o_proj = candle_nn::linear_no_bias(q_dim, config.hidden_size, vb.pp("o_proj"))?;
        let q_norm = RmsNorm::load(config.head_dim, config.rms_norm_eps, vb.pp("q_norm")).ok();
        let k_norm = RmsNorm::load(config.head_dim, config.rms_norm_eps, vb.pp("k_norm")).ok();
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            q_norm,
            k_norm,
            num_heads: config.num_attention_heads,
            num_kv_heads: config.num_key_value_heads,
            head_dim: config.head_dim,
            kv_cache: None,
        })
    }
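
    /// Attend over `x` of shape `(batch, seq_len, hidden_size)`.
    ///
    /// `cos`/`sin` are precomputed rotary tables indexed from `start_pos`, and
    /// `mask` (if provided) is broadcast-added to the attention scores before
    /// the softmax. New keys/values are appended to the internal KV cache.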
    pub fn forward(
        &mut self,
        x: &Tensor,
        cos: &Tensor,
        sin: &Tensor,
        start_pos: usize,
        mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let (batch, seq_len, _) = x.dims3()?;
        let q = self.q_proj.forward(x)?;
        let k = self.k_proj.forward(x)?;
        let v = self.v_proj.forward(x)?;
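
        // Split heads: (batch, seq, n_heads * head_dim) -> (batch, n_heads, seq, head_dim).
        // `contiguous` avoids strided-layout issues in the matmuls below.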
        let q = q
            .reshape((batch, seq_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;
        let k = k
            .reshape((batch, seq_len, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;
        let v = v
            .reshape((batch, seq_len, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;
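
        // Optional per-head RMSNorm on queries and keys (identity when the
        // norm weights were not present in the checkpoint).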
        let q = if let Some(ref qn) = self.q_norm {
            qn.forward(&q)?
        } else {
            q
        };
        let k = if let Some(ref kn) = self.k_norm {
            kn.forward(&k)?
        } else {
            k
        };
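
        // Slice the rotary tables for the current positions and add broadcast
        // dims so they apply across batch and heads.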
        let cos_slice = cos
            .narrow(0, start_pos, seq_len)?
            .unsqueeze(0)?
            .unsqueeze(0)?;
        let sin_slice = sin
            .narrow(0, start_pos, seq_len)?
            .unsqueeze(0)?
            .unsqueeze(0)?;
        let (q, k) = apply_rotary_emb(&q, &k, &cos_slice, &sin_slice)?;
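
        // Append the new keys/values to the cache along the sequence dimension.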
        let (k, v) = if let Some((prev_k, prev_v)) = &self.kv_cache {
            let k = Tensor::cat(&[prev_k, &k], 2)?;
            let v = Tensor::cat(&[prev_v, &v], 2)?;
            (k, v)
        } else {
            (k, v)
        };
        self.kv_cache = Some((k.clone(), v.clone()));
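
        // Repeat K/V heads so each group of query heads attends to its shared
        // key/value head (standard GQA expansion).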
        let repeat_factor = self.num_heads / self.num_kv_heads;
        let kv_len = k.dim(2)?;
        let k = if repeat_factor > 1 {
            k.unsqueeze(2)?
                .repeat(&[1, 1, repeat_factor, 1, 1])?
                .reshape((batch, self.num_heads, kv_len, self.head_dim))?
        } else {
            k
        };
        let v = if repeat_factor > 1 {
            v.unsqueeze(2)?
                .repeat(&[1, 1, repeat_factor, 1, 1])?
                .reshape((batch, self.num_heads, kv_len, self.head_dim))?
        } else {
            v
        };
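
        // Scaled dot-product attention: softmax(Q K^T / sqrt(head_dim)) V.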
        let scale = (self.head_dim as f64).sqrt();
        let attn_weights = (q.matmul(&k.transpose(2, 3)?)? / scale)?;
        let attn_weights = if let Some(mask) = mask {
            attn_weights.broadcast_add(mask)?
        } else {
            attn_weights
        };
        let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
        let attn_output = attn_weights.matmul(&v)?;
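
        // Merge heads back: (batch, n_heads, seq, head_dim) -> (batch, seq, n_heads * head_dim).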
        let attn_output = attn_output.transpose(1, 2)?.reshape((
            batch,
            seq_len,
            self.num_heads * self.head_dim,
        ))?;
        self.o_proj.forward(&attn_output)
    }
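
    /// Drop the KV cache, e.g. before processing a new sequence.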
    pub fn clear_cache(&mut self) {
        self.kv_cache = None;
    }
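
    /// Return a clone of the current KV cache, if any.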
    pub fn cache_state(&self) -> Option<(Tensor, Tensor)> {
        self.kv_cache.clone()
    }
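
    /// Restore (or clear) the KV cache from a previously captured state.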
    pub fn set_cache_state(&mut self, kv_cache: Option<(Tensor, Tensor)>) {
        self.kv_cache = kv_cache;
    }
}
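
// Compact Debug impl: reports head counts and cache presence rather than the
// projection weights and cached tensors themselves.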
impl std::fmt::Debug for GroupedQueryAttention {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GroupedQueryAttention")
            .field("num_heads", &self.num_heads)
            .field("num_kv_heads", &self.num_kv_heads)
            .field("head_dim", &self.head_dim)
            .field("has_cache", &self.kv_cache.is_some())
            .finish()
    }
}