use candle_core::{Result, Tensor};
use candle_nn::{Linear, Module, VarBuilder, linear, linear_no_bias};
use crate::nn::attention::config::AttentionConfig;
use crate::nn::attention::rope_strategy::RopeStrategy;
use crate::nn::attention::{eager_attention_forward_with_sliding_window, flash_attention_forward};
use crate::nn::kv_cache::KVCache;
use crate::nn::norm::RMSNorm;
/// Multi-head attention layer with separate query and key/value head counts,
/// per-head RMS normalization of queries and keys, a pluggable rotary
/// embedding strategy, an optional sliding window, and an optional
/// flash-attention path.
#[derive(Debug, Clone)]
pub struct UnifiedAttention {
// Query projection: `hidden_size -> num_heads * head_dim`.
q_proj: Linear,
// Key projection: `hidden_size -> num_kv_heads * head_dim`.
k_proj: Linear,
// Value projection: `hidden_size -> num_kv_heads * head_dim`.
v_proj: Linear,
// Output projection: `num_heads * head_dim -> hidden_size`.
o_proj: Linear,
// RMS norm of width `head_dim`, applied to the reshaped queries before RoPE.
q_norm: RMSNorm,
// RMS norm of width `head_dim`, applied to the reshaped keys before RoPE.
k_norm: RMSNorm,
// Number of query heads (`config.num_attention_heads()`).
num_heads: usize,
// Number of key/value heads (`config.num_key_value_heads()`).
num_kv_heads: usize,
// Width of a single attention head.
head_dim: usize,
// `num_heads / num_kv_heads`; handed to the eager attention implementation.
num_kv_groups: usize,
// Attention scale factor, `head_dim^-0.5`.
scaling: f64,
// Strategy used to apply the (cos, sin) position embeddings to queries and keys.
rope_strategy: RopeStrategy,
// Sliding-window size taken from the config; `None` is forwarded as-is to
// the attention implementations.
sliding_window: Option<usize>,
// Index of this layer; used as the layer key when updating the KV cache.
layer_idx: usize,
// Requests the flash-attention path; it is only taken when the query
// tensor also lives on a CUDA device.
use_flash_attn: bool,
}
impl UnifiedAttention {
pub fn new<C: AttentionConfig>(
config: &C,
rope_strategy: RopeStrategy,
layer_idx: usize,
use_flash_attn: bool,
vb: VarBuilder,
) -> Result<Self> {
let head_dim = config.head_dim();
let num_heads = config.num_attention_heads();
let num_kv_heads = config.num_key_value_heads();
let hidden_size = config.hidden_size();
let q_proj = if config.attention_bias() {
linear(hidden_size, num_heads * head_dim, vb.pp("q_proj"))?
} else {
linear_no_bias(hidden_size, num_heads * head_dim, vb.pp("q_proj"))?
};
let k_proj = if config.attention_bias() {
linear(hidden_size, num_kv_heads * head_dim, vb.pp("k_proj"))?
} else {
linear_no_bias(hidden_size, num_kv_heads * head_dim, vb.pp("k_proj"))?
};
let v_proj = if config.attention_bias() {
linear(hidden_size, num_kv_heads * head_dim, vb.pp("v_proj"))?
} else {
linear_no_bias(hidden_size, num_kv_heads * head_dim, vb.pp("v_proj"))?
};
let o_proj = if config.attention_bias() {
linear(num_heads * head_dim, hidden_size, vb.pp("o_proj"))?
} else {
linear_no_bias(num_heads * head_dim, hidden_size, vb.pp("o_proj"))?
};
let q_norm = RMSNorm::new(head_dim, config.rms_norm_eps(), vb.pp("q_norm"))?;
let k_norm = RMSNorm::new(head_dim, config.rms_norm_eps(), vb.pp("k_norm"))?;
let num_kv_groups = num_heads / num_kv_heads;
let scaling = (head_dim as f64).powf(-0.5);
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
q_norm,
k_norm,
num_heads,
num_kv_heads,
head_dim,
num_kv_groups,
scaling,
rope_strategy,
sliding_window: config.sliding_window(),
layer_idx,
use_flash_attn,
})
}
pub fn layer_idx(&self) -> usize {
self.layer_idx
}
pub fn sliding_window(&self) -> Option<usize> {
self.sliding_window
}
pub fn forward(
&self,
hidden_states: &Tensor,
position_embeddings: (&Tensor, &Tensor),
attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
let (batch, seq_len, _) = hidden_states.dims3()?;
let query_states = self.q_proj.forward(hidden_states)?;
let key_states = self.k_proj.forward(hidden_states)?;
let value_states = self.v_proj.forward(hidden_states)?;
let query_states = query_states.reshape((batch, seq_len, self.num_heads, self.head_dim))?;
let key_states = key_states.reshape((batch, seq_len, self.num_kv_heads, self.head_dim))?;
let value_states =
value_states.reshape((batch, seq_len, self.num_kv_heads, self.head_dim))?;
let query_states = self.q_norm.forward(&query_states)?;
let key_states = self.k_norm.forward(&key_states)?;
let query_states = query_states.transpose(1, 2)?;
let key_states = key_states.transpose(1, 2)?;
let value_states = value_states.transpose(1, 2)?;
let (cos, sin) = position_embeddings;
let (query_states, key_states) =
self.rope_strategy
.apply(&query_states, &key_states, cos, sin)?;
let use_flash = self.use_flash_attn && query_states.device().is_cuda();
let attn_output = if use_flash {
flash_attention_forward(
&query_states,
&key_states,
&value_states,
self.scaling,
self.sliding_window,
)?
} else {
eager_attention_forward_with_sliding_window(
&query_states,
&key_states,
&value_states,
attention_mask,
self.num_kv_groups,
self.scaling,
self.sliding_window,
)?
};
let attn_output = attn_output.reshape((batch, seq_len, self.num_heads * self.head_dim))?;
self.o_proj.forward(&attn_output)
}
pub fn forward_with_cache(
&self,
hidden_states: &Tensor,
position_embeddings: (&Tensor, &Tensor),
attention_mask: Option<&Tensor>,
cache: &mut KVCache,
) -> Result<Tensor> {
let (batch, seq_len, _) = hidden_states.dims3()?;
let query_states = self.q_proj.forward(hidden_states)?;
let key_states = self.k_proj.forward(hidden_states)?;
let value_states = self.v_proj.forward(hidden_states)?;
let query_states = query_states.reshape((batch, seq_len, self.num_heads, self.head_dim))?;
let key_states = key_states.reshape((batch, seq_len, self.num_kv_heads, self.head_dim))?;
let value_states =
value_states.reshape((batch, seq_len, self.num_kv_heads, self.head_dim))?;
let query_states = self.q_norm.forward(&query_states)?;
let key_states = self.k_norm.forward(&key_states)?;
let query_states = query_states.transpose(1, 2)?;
let key_states = key_states.transpose(1, 2)?;
let value_states = value_states.transpose(1, 2)?;
let (cos, sin) = position_embeddings;
let (query_states, key_states) =
self.rope_strategy
.apply(&query_states, &key_states, cos, sin)?;
let (key_states, value_states) =
cache.update(self.layer_idx, &key_states, &value_states)?;
let use_flash = self.use_flash_attn && query_states.device().is_cuda();
let attn_output = if use_flash {
flash_attention_forward(
&query_states,
&key_states,
&value_states,
self.scaling,
self.sliding_window,
)?
} else {
eager_attention_forward_with_sliding_window(
&query_states,
&key_states,
&value_states,
attention_mask,
self.num_kv_groups,
self.scaling,
self.sliding_window,
)?
};
let attn_output = attn_output.reshape((batch, seq_len, self.num_heads * self.head_dim))?;
self.o_proj.forward(&attn_output)
}
}